diff --git a/velox/connectors/hive/HiveDataSource.cpp b/velox/connectors/hive/HiveDataSource.cpp index c654afae24219..0378bba9d5153 100644 --- a/velox/connectors/hive/HiveDataSource.cpp +++ b/velox/connectors/hive/HiveDataSource.cpp @@ -185,6 +185,7 @@ void HiveDataSource::addSplit(std::shared_ptr split) { split_, "Previous split has not been processed yet. Call next to process the split."); split_ = std::dynamic_pointer_cast(split); + facebook::velox::core::sparkInputFileName = split_->getFileName(); VELOX_CHECK_NOT_NULL(split_, "Wrong type of split"); VLOG(1) << "Adding split " << split_->toString(); diff --git a/velox/core/QueryConfig.cpp b/velox/core/QueryConfig.cpp index de251815c4c36..ceb46b2005576 100644 --- a/velox/core/QueryConfig.cpp +++ b/velox/core/QueryConfig.cpp @@ -19,7 +19,7 @@ #include "velox/core/QueryConfig.h" namespace facebook::velox::core { - +thread_local std::string sparkInputFileName = ""; double toBytesPerCapacityUnit(CapacityUnit unit) { switch (unit) { case CapacityUnit::BYTE: diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index eebc359c00775..6808e5c3ff4df 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -27,6 +27,8 @@ enum class CapacityUnit { PETABYTE }; +extern thread_local std::string sparkInputFileName; + double toBytesPerCapacityUnit(CapacityUnit unit); CapacityUnit valueOfCapacityUnit(const std::string& unitStr); @@ -298,6 +300,9 @@ class QueryConfig { /// The current spark partition id. static constexpr const char* kSparkPartitionId = "spark.partition_id"; + /// The file name of the current Spark task. + static constexpr const char* kInputFileName = "input_file_name"; + /// The number of local parallel table writer operators per task. static constexpr const char* kTaskWriterCount = "task_writer_count"; diff --git a/velox/functions/sparksql/InputFileName.h b/velox/functions/sparksql/InputFileName.h new file mode 100644 index 0000000000000..374a6219a04e5 --- /dev/null +++ b/velox/functions/sparksql/InputFileName.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/core/QueryConfig.h" +#include "velox/functions/Macros.h" + +namespace facebook::velox::functions::sparksql { +template +struct InputFileNameFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call(out_type& result) { + result = facebook::velox::core::sparkInputFileName; + } +}; +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index a9edb617acda9..d9b403fbc437d 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -31,6 +31,7 @@ #include "velox/functions/sparksql/DateTimeFunctions.h" #include "velox/functions/sparksql/Hash.h" #include "velox/functions/sparksql/In.h" +#include "velox/functions/sparksql/InputFileName.h" #include "velox/functions/sparksql/LeastGreatest.h" #include "velox/functions/sparksql/MightContain.h" #include "velox/functions/sparksql/MonotonicallyIncreasingId.h" @@ -375,6 +376,8 @@ void registerFunctions(const std::string& prefix) { {prefix + "monotonically_increasing_id"}); registerFunction>({prefix + "uuid"}); + registerFunction( + {prefix + "input_file_name"}); registerArrayUnionFunctions(prefix); registerArrayUnionFunctions(prefix); registerArrayUnionFunctions(prefix); diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index 614be32439ed5..779e8234193b4 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -47,6 +47,7 @@ add_executable( StringToMapTest.cpp UnscaledValueFunctionTest.cpp UuidTest.cpp + InputFileNameTest.cpp XxHash64Test.cpp) add_test(velox_functions_spark_test velox_functions_spark_test) diff --git a/velox/functions/sparksql/tests/InputFileNameTest.cpp b/velox/functions/sparksql/tests/InputFileNameTest.cpp new file mode 100644 index 0000000000000..28cd7d9ad0baf --- /dev/null +++ b/velox/functions/sparksql/tests/InputFileNameTest.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class InputFileNameTest : public SparkFunctionBaseTest { + protected: + void testInputFileName(std::string fileName) { + facebook::velox::core::sparkInputFileName = fileName; + EXPECT_EQ( + evaluateOnce( + "input_file_name()", makeRowVector(ROW({}), 1)), + std::string(fileName)); + } +}; + +TEST_F(InputFileNameTest, basic) { + testInputFileName("text.txt"); +} +} // namespace +} // namespace facebook::velox::functions::sparksql::test