[SPARK-46355][SQL][TESTS][FOLLOWUP] Test to check number of open files
### What changes were proposed in this pull request?
Added a unit test to ensure that the XML parser does not keep an unbounded number of files open.

### Why are the changes needed?
To test the [fix](apache#44287) for the "too many open files" issue.
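
For context, the failure mode being guarded against shows up when reading a directory that contains many XML part files: both schema inference and parsing visit every file, and if the underlying streams are not closed promptly the process can exhaust its limit on open file descriptors. A minimal sketch using the public reader API, assuming a hypothetical local path with many part files (not the actual reproduction from the linked issue):

```scala
import org.apache.spark.sql.SparkSession

object XmlOpenFilesRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("xml-open-files-repro")
      .getOrCreate()

    // Hypothetical directory containing many XML part files.
    val path = "/tmp/xml-with-many-part-files"

    // Schema inference and parsing both scan every part file; the fix referenced
    // above is meant to close each file's stream once that file has been processed,
    // so only a bounded number of files are open at any time.
    val df = spark.read.option("rowTag", "row").xml(path)
    println(df.count())

    spark.stop()
  }
}
```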

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit test.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#45074 from sandip-db/xml_open_test.

Authored-by: Sandip Agarwala <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
sandip-db authored and dongjoon-hyun committed Feb 9, 2024
1 parent 3f5faaa commit d179f75
Showing 1 changed file with 91 additions and 1 deletion.
@@ -22,6 +22,7 @@ import java.nio.file.{Files, Path, Paths}
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDateTime}
import java.util.TimeZone
import java.util.concurrent.ConcurrentHashMap
import javax.xml.stream.XMLStreamException

import scala.collection.immutable.ArraySeq
@@ -31,10 +32,11 @@ import scala.jdk.CollectionConverters._

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FSDataInputStream
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.io.compress.GzipCodec

import org.apache.spark.{SparkException, SparkFileNotFoundException}
import org.apache.spark.{DebugFilesystem, SparkException, SparkFileNotFoundException}
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoders, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.xml.XmlOptions
@@ -2881,4 +2883,92 @@ class XmlSuite
checkValidation("field name with space", "Illegal name character ' '")
checkValidation("field", "", false)
}

test("SPARK-46355: Check Number of open files") {
withSQLConf("fs.file.impl" -> classOf[XmlSuiteDebugFileSystem].getName,
"fs.file.impl.disable.cache" -> "true") {
withTempDir { dir =>
val path = dir.getCanonicalPath
val numFiles = 10
val numRecords = 10000
val data =
spark.sparkContext.parallelize(
(0 until numRecords).map(i => Row(i.toLong, (i * 2).toLong)))
val schema = buildSchema(field("a1", LongType), field("a2", LongType))
val df = spark.createDataFrame(data, schema)

// Write numFiles files
df.repartition(numFiles)
.write
.mode(SaveMode.Overwrite)
.option("rowTag", "row")
.xml(path)

// Serialized file read for Schema inference
val dfRead = spark.read
.option("rowTag", "row")
.xml(path)
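        // Each of the numFiles files should have been opened exactly once, and at least
        // one thread should have read more than one file, so the per-thread "only one
        // open file" check in the mock file system actually gets exercised.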
        assert(XmlSuiteDebugFileSystem.totalFiles() === numFiles)
        assert(XmlSuiteDebugFileSystem.maxFiles() > 1)

        XmlSuiteDebugFileSystem.reset()
        // Serialized file read for parsing across multiple executors
        assert(dfRead.count() === numRecords)
        assert(XmlSuiteDebugFileSystem.totalFiles() === numFiles)
        assert(XmlSuiteDebugFileSystem.maxFiles() > 1)
      }
    }
  }
}

// Mock file system that checks the number of open files
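// DebugFilesystem is Spark's test wrapper around the local file system that tracks opened
// streams; this subclass additionally counts how many files are open on each thread.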
class XmlSuiteDebugFileSystem extends DebugFilesystem {

  override def open(f: org.apache.hadoop.fs.Path, bufferSize: Int): FSDataInputStream = {
    val wrapped: FSDataInputStream = super.open(f, bufferSize)
    // All files should be closed before reading next one
    XmlSuiteDebugFileSystem.open()
    new FSDataInputStream(wrapped.getWrappedStream) {
      override def close(): Unit = {
        try {
          wrapped.close()
        } finally {
          XmlSuiteDebugFileSystem.close()
        }
      }
    }
  }
}

object XmlSuiteDebugFileSystem {
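  // Counters are keyed by thread id: each task reads its files on its own thread, so a
  // well-behaved reader should never have more than one file open per thread at a time.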
  private val openCounterPerThread = new ConcurrentHashMap[Long, Int]()
  private val maxFilesPerThread = new ConcurrentHashMap[Long, Int]()

  def reset() : Unit = {
    maxFilesPerThread.clear()
  }

  def open() : Unit = {
    val threadId = Thread.currentThread().getId
    // assert that there are no open files for this executor
    assert(openCounterPerThread.getOrDefault(threadId, 0) == 0)
    openCounterPerThread.put(threadId, 1)
    maxFilesPerThread.put(threadId, maxFilesPerThread.getOrDefault(threadId, 0) + 1)
  }

  def close(): Unit = {
    val threadId = Thread.currentThread().getId
    if (openCounterPerThread.get(threadId) == 1) {
      openCounterPerThread.put(threadId, 0)
    }
    assert(openCounterPerThread.get(threadId) == 0)
  }

  def maxFiles() : Int = {
    maxFilesPerThread.values().asScala.max
  }

  def totalFiles() : Int = {
    maxFilesPerThread.values().asScala.sum
  }
}
