Skip to content

Commit

Permalink
spark input_file_name function support
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyangxiaozhu committed Mar 24, 2024
1 parent bc607cb commit 37b2f79
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 1 deletion.
1 change: 1 addition & 0 deletions velox/connectors/hive/HiveDataSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ void HiveDataSource::addSplit(std::shared_ptr<ConnectorSplit> split) {
split_,
"Previous split has not been processed yet. Call next to process the split.");
split_ = std::dynamic_pointer_cast<HiveConnectorSplit>(split);
// Validate the cast before the first dereference: dynamic_pointer_cast yields
// a null pointer when 'split' is not a HiveConnectorSplit, so touching
// 'split_' ahead of this check would crash instead of raising the error below.
VELOX_CHECK_NOT_NULL(split_, "Wrong type of split");
// Record the split's file name in the thread-local consumed by the Spark
// input_file_name() function.
facebook::velox::core::sparkInputFileName = split_->getFileName();

VLOG(1) << "Adding split " << split_->toString();
Expand Down
2 changes: 1 addition & 1 deletion velox/core/QueryConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "velox/core/QueryConfig.h"

namespace facebook::velox::core {

// Per-thread file name backing the Spark input_file_name() function. Starts
// out empty; HiveDataSource::addSplit assigns it from the incoming split.
thread_local std::string sparkInputFileName{};
double toBytesPerCapacityUnit(CapacityUnit unit) {
switch (unit) {
case CapacityUnit::BYTE:
Expand Down
5 changes: 5 additions & 0 deletions velox/core/QueryConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ enum class CapacityUnit {
PETABYTE
};

/// Thread-local file name returned by the Spark input_file_name() function.
/// Defined in QueryConfig.cpp, where it is initialized to an empty string.
/// Each thread has its own copy, so a value assigned on one thread is not
/// visible from another.
/// NOTE(review): this relies on the writer (HiveDataSource::addSplit) and the
/// reader (InputFileNameFunction::call) running on the same thread - confirm.
extern thread_local std::string sparkInputFileName;

/// Maps a CapacityUnit to a double; implemented in QueryConfig.cpp as a switch
/// over 'unit'. Presumably the number of bytes in one such unit - confirm
/// against the definition.
double toBytesPerCapacityUnit(CapacityUnit unit);

/// Returns the CapacityUnit corresponding to 'unitStr'.
/// NOTE(review): accepted spellings and the behavior for unrecognized input
/// are not visible from this declaration - check the definition.
CapacityUnit valueOfCapacityUnit(const std::string& unitStr);
Expand Down Expand Up @@ -298,6 +300,9 @@ class QueryConfig {
/// The current spark partition id.
static constexpr const char* kSparkPartitionId = "spark.partition_id";

/// Config key ("input_file_name") for the name of the input file being read
/// by the current Spark task.
/// NOTE(review): nothing in this change reads this key; input_file_name() is
/// served from core::sparkInputFileName instead - confirm whether the key is
/// still needed.
static constexpr const char* kInputFileName = "input_file_name";

/// The number of local parallel table writer operators per task.
static constexpr const char* kTaskWriterCount = "task_writer_count";

Expand Down
30 changes: 30 additions & 0 deletions velox/functions/sparksql/InputFileName.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/core/QueryConfig.h"
#include "velox/functions/Macros.h"

namespace facebook::velox::functions::sparksql {
/// input_file_name() -> varchar
///
/// Spark-compatible input_file_name(): takes no arguments and returns the
/// current value of the thread-local core::sparkInputFileName, which
/// HiveDataSource::addSplit assigns from the split being added. The result is
/// an empty string if nothing has been assigned on the calling thread.
template <typename T>
struct InputFileNameFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  // The result comes from per-thread state rather than from the (empty)
  // argument list, so it can differ between calls. Declaring the function
  // non-deterministic keeps the zero-argument call from being folded into a
  // constant, mirroring what other state-dependent functions do.
  static constexpr bool is_deterministic = false;

  /// Copies the calling thread's current input file name into 'result'.
  FOLLY_ALWAYS_INLINE void call(out_type<Varchar>& result) {
    result = facebook::velox::core::sparkInputFileName;
  }
};
} // namespace facebook::velox::functions::sparksql
3 changes: 3 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "velox/functions/sparksql/DateTimeFunctions.h"
#include "velox/functions/sparksql/Hash.h"
#include "velox/functions/sparksql/In.h"
#include "velox/functions/sparksql/InputFileName.h"
#include "velox/functions/sparksql/LeastGreatest.h"
#include "velox/functions/sparksql/MightContain.h"
#include "velox/functions/sparksql/MonotonicallyIncreasingId.h"
Expand Down Expand Up @@ -375,6 +376,8 @@ void registerFunctions(const std::string& prefix) {
{prefix + "monotonically_increasing_id"});

registerFunction<UuidFunction, Varchar, Constant<int64_t>>({prefix + "uuid"});
registerFunction<InputFileNameFunction, Varchar>(
{prefix + "input_file_name"});
registerArrayUnionFunctions<bool>(prefix);
registerArrayUnionFunctions<int8_t>(prefix);
registerArrayUnionFunctions<int16_t>(prefix);
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ add_executable(
StringToMapTest.cpp
UnscaledValueFunctionTest.cpp
UuidTest.cpp
InputFileNameTest.cpp
XxHash64Test.cpp)

add_test(velox_functions_spark_test velox_functions_spark_test)
Expand Down
38 changes: 38 additions & 0 deletions velox/functions/sparksql/tests/InputFileNameTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

namespace facebook::velox::functions::sparksql::test {
namespace {

/// Fixture for the Spark input_file_name() function.
class InputFileNameTest : public SparkFunctionBaseTest {
 protected:
  /// Sets the thread-local input file name to 'fileName', evaluates
  /// input_file_name() over a single-row input with no columns, and expects
  /// the function to return 'fileName'. The previous thread-local value is
  /// put back once the expression has been evaluated so that the override
  /// does not leak into other tests running on the same thread.
  void testInputFileName(const std::string& fileName) {
    auto& currentFileName = facebook::velox::core::sparkInputFileName;
    const std::string previousFileName = currentFileName;
    currentFileName = fileName;

    const auto actual = evaluateOnce<std::string>(
        "input_file_name()", makeRowVector(ROW({}), 1));

    // Restore before asserting; EXPECT_EQ only reads 'actual' and 'fileName'.
    currentFileName = previousFileName;

    EXPECT_EQ(actual, fileName);
  }
};

// Checks that input_file_name() returns the file name ("text.txt") that the
// fixture helper stores in the thread-local before evaluating the expression.
TEST_F(InputFileNameTest, basic) {
testInputFileName("text.txt");
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit 37b2f79

Please sign in to comment.