Skip to content

Commit

Permalink
[SPARK-43117][CONNECT] Make ProtoUtils.abbreviate support repeated fields
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Make `ProtoUtils.abbreviate` support repeated fields

### Why are the changes needed?
The existing implementation does not work for repeated fields (strings/messages).

We don't have `repeated bytes` in Spark Connect for now, so we leave it alone.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
added UTs

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#45056 from zhengruifeng/proto_abbr_repeat.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Feb 8, 2024
1 parent 729fc8e commit 71b76dc
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,17 @@ private[connect] object ProtoUtils {
val size = string.length
val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE)
if (size > threshold) {
builder.setField(field, createString(string.take(threshold), size))
builder.setField(field, truncateString(string, threshold))
}

case (field: FieldDescriptor, strings: java.lang.Iterable[_])
if field.getJavaType == FieldDescriptor.JavaType.STRING && field.isRepeated
&& strings != null =>
val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE)
strings.iterator().asScala.zipWithIndex.foreach {
case (string: String, i) if string != null && string.length > threshold =>
builder.setRepeatedField(field, i, truncateString(string, threshold))
case _ =>
}

case (field: FieldDescriptor, byteString: ByteString)
Expand All @@ -69,23 +79,33 @@ private[connect] object ProtoUtils {
.concat(createTruncatedByteString(size)))
}

// TODO(SPARK-43117): should also support 1, repeated msg; 2, map<xxx, msg>
// TODO(SPARK-46988): should support map<xxx, msg>
case (field: FieldDescriptor, msg: Message)
if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg != null =>
if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && !field.isRepeated
&& msg != null =>
builder.setField(field, abbreviate(msg, thresholds))

case (field: FieldDescriptor, msgs: java.lang.Iterable[_])
if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && field.isRepeated
&& msgs != null =>
msgs.iterator().asScala.zipWithIndex.foreach {
case (msg: Message, i) if msg != null =>
builder.setRepeatedField(field, i, abbreviate(msg, thresholds))
case _ =>
}

case _ =>
}

builder.build()
}

private def createTruncatedByteString(size: Int): ByteString = {
ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]")
/**
 * Cuts `string` down to its first `threshold` characters and appends a marker
 * recording the original length (formatted via the object's `format`).
 */
private def truncateString(string: String, threshold: Int): String = {
  val prefix = string.take(threshold)
  val marker = s"[truncated(size=${format.format(string.length)})]"
  prefix + marker
}

private def createString(prefix: String, size: Int): String = {
s"$prefix[truncated(size=${format.format(size)})]"
/**
 * Builds a ByteString placeholder that records the size of the payload it replaces.
 */
private def createTruncatedByteString(size: Int): ByteString = {
  val marker = "[truncated(size=" + format.format(size) + ")]"
  ByteString.copyFromUtf8(marker)
}

// Because Spark Connect operation tags are also set as SparkContext Job tags, they cannot contain
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,77 @@ class AbbreviateSuite extends SparkFunSuite {
}
}

// SPARK-43117: repeated string fields must be abbreviated element-wise.
test("truncate repeated strings") {
  val sql = proto.Relation
    .newBuilder()
    .setSql(proto.SQL.newBuilder().setQuery("SELECT * FROM T"))
    .build()
  // 10 column names, each exactly 1024 characters long.
  val names = Seq.range(0, 10).map(i => i.toString * 1024)
  val drop = proto.Drop.newBuilder().setInput(sql).addAllColumnNames(names.asJava).build()

  Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold =>
    val truncated = ProtoUtils.abbreviate(drop, threshold)
    // Check the RESULT keeps its message type (the original asserted on the
    // input `drop`, which is trivially true and verified nothing).
    assert(truncated.isInstanceOf[proto.Drop])

    val truncatedNames = truncated.asInstanceOf[proto.Drop].getColumnNamesList.asScala.toSeq
    assert(truncatedNames.length === 10)

    if (threshold < 1024) {
      // Each name is cut at `threshold`, so the marker starts exactly there.
      truncatedNames.foreach { truncatedName =>
        assert(truncatedName.indexOf("[truncated") === threshold)
      }
    } else {
      // Threshold not exceeded (length is not strictly greater): untouched.
      truncatedNames.foreach { truncatedName =>
        assert(truncatedName.indexOf("[truncated") === -1)
        assert(truncatedName.length === 1024)
      }
    }
  }
}

// SPARK-43117: repeated message fields must be abbreviated recursively, per element.
test("truncate repeated messages") {
  val sql = proto.Relation
    .newBuilder()
    .setSql(proto.SQL.newBuilder().setQuery("SELECT * FROM T"))
    .build()

  // 10 column expressions whose unparsed identifiers are each 1024 chars long.
  val cols = Seq.range(0, 10).map { i =>
    proto.Expression
      .newBuilder()
      .setUnresolvedAttribute(
        proto.Expression.UnresolvedAttribute
          .newBuilder()
          .setUnparsedIdentifier(i.toString * 1024)
          .build())
      .build()
  }
  val drop = proto.Drop.newBuilder().setInput(sql).addAllColumns(cols.asJava).build()

  Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold =>
    val truncated = ProtoUtils.abbreviate(drop, threshold)
    // Check the RESULT keeps its message type (the original asserted on the
    // input `drop`, which is trivially true and verified nothing).
    assert(truncated.isInstanceOf[proto.Drop])

    val truncatedCols = truncated.asInstanceOf[proto.Drop].getColumnsList.asScala.toSeq
    assert(truncatedCols.length === 10)

    if (threshold < 1024) {
      // Nested string fields inside each repeated message are truncated too.
      truncatedCols.foreach { truncatedCol =>
        assert(truncatedCol.isInstanceOf[proto.Expression])
        val truncatedName = truncatedCol.getUnresolvedAttribute.getUnparsedIdentifier
        assert(truncatedName.indexOf("[truncated") === threshold)
      }
    } else {
      // Threshold not exceeded: nested identifiers are untouched.
      truncatedCols.foreach { truncatedCol =>
        assert(truncatedCol.isInstanceOf[proto.Expression])
        val truncatedName = truncatedCol.getUnresolvedAttribute.getUnparsedIdentifier
        assert(truncatedName.indexOf("[truncated") === -1)
        assert(truncatedName.length === 1024)
      }
    }
  }
}

test("truncate bytes: simple python udf") {
Seq(1, 8, 16, 64, 256).foreach { numBytes =>
val bytes = Array.ofDim[Byte](numBytes)
Expand Down

0 comments on commit 71b76dc

Please sign in to comment.