Skip to content

Commit

Permalink
Add mock test for FileHandle
Browse files Browse the repository at this point in the history
  • Loading branch information
kingcrimsontianyu committed Jan 24, 2025
1 parent 44f8b59 commit 6036a27
Show file tree
Hide file tree
Showing 9 changed files with 312 additions and 99 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ set(SOURCES
"src/cufile/config.cpp"
"src/cufile/driver.cpp"
"src/defaults.cpp"
"src/detail/file_handle_dep.cpp"
"src/error.cpp"
"src/file_handle.cpp"
"src/posix_io.cpp"
Expand Down
83 changes: 83 additions & 0 deletions cpp/include/kvikio/detail/file_handle_dep.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <sys/stat.h>
#include <sys/types.h>

#include <cstddef>
#include <cstdlib>
#include <stdexcept>
#include <system_error>
#include <utility>

#include <kvikio/buffer.hpp>
#include <kvikio/cufile/config.hpp>
#include <kvikio/defaults.hpp>
#include <kvikio/error.hpp>
#include <kvikio/parallel_operation.hpp>
#include <kvikio/posix_io.hpp>
#include <kvikio/shim/cufile.hpp>
#include <kvikio/stream.hpp>
#include <kvikio/utils.hpp>

namespace kvikio {
class FileHandle;

namespace detail {
class FileHandleDependencyBase {
public:
virtual ~FileHandleDependencyBase() = default;
void set_file_handle(FileHandle* file_handle);

virtual int open_fd(const std::string& file_path,
const std::string& flags,
bool o_direct,
mode_t mode) = 0;
virtual void close_fd(int fd) = 0;
virtual CUfileError_t cuFile_handle_register(CUfileHandle_t* fh, CUfileDescr_t* descr) = 0;
virtual void cuFile_handle_deregister(CUfileHandle_t fh) = 0;
virtual bool is_compat_mode_preferred() const noexcept = 0;
virtual bool is_compat_mode_preferred_for_async() const noexcept = 0;
virtual bool is_compat_mode_preferred_for_async(CompatMode requested_compat_mode) = 0;

protected:
FileHandle* _file_handle;
};

class FileHandleDependencyProduction : public FileHandleDependencyBase {
public:
/**
* @brief Open file using `open(2)`
*
* @param flags Open flags given as a string
* @param o_direct Append O_DIRECT to `flags`
* @param mode Access modes
* @return File descriptor
*/
virtual int open_fd(const std::string& file_path,
const std::string& flags,
bool o_direct,
mode_t mode) override;
virtual void close_fd(int fd) override;
virtual CUfileError_t cuFile_handle_register(CUfileHandle_t* fh, CUfileDescr_t* descr) override;
virtual void cuFile_handle_deregister(CUfileHandle_t fh) override;
virtual bool is_compat_mode_preferred() const noexcept override;
virtual bool is_compat_mode_preferred_for_async() const noexcept override;
virtual bool is_compat_mode_preferred_for_async(CompatMode requested_compat_mode) override;
};
} // namespace detail
} // namespace kvikio
8 changes: 7 additions & 1 deletion cpp/include/kvikio/file_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <kvikio/buffer.hpp>
#include <kvikio/cufile/config.hpp>
#include <kvikio/defaults.hpp>
#include <kvikio/detail/file_handle_dep.hpp>
#include <kvikio/error.hpp>
#include <kvikio/parallel_operation.hpp>
#include <kvikio/posix_io.hpp>
Expand All @@ -43,6 +44,7 @@ namespace kvikio {
*/
class FileHandle {
private:
std::unique_ptr<detail::FileHandleDependencyBase> _dep;
// We use two file descriptors, one opened with the O_DIRECT flag and one without.
int _fd_direct_on{-1};
int _fd_direct_off{-1};
Expand All @@ -51,6 +53,8 @@ class FileHandle {
mutable std::size_t _nbytes{0}; // The size of the underlying file, zero means unknown.
CUfileHandle_t _handle{};

friend class detail::FileHandleDependencyProduction;

/**
* @brief Given a requested compatibility mode, whether it is expected to reduce to `ON` for
* asynchronous I/O.
Expand Down Expand Up @@ -85,7 +89,9 @@ class FileHandle {
FileHandle(const std::string& file_path,
const std::string& flags = "r",
mode_t mode = m644,
CompatMode compat_mode = defaults::compat_mode());
CompatMode compat_mode = defaults::compat_mode(),
std::unique_ptr<detail::FileHandleDependencyBase> dep =
std::make_unique<detail::FileHandleDependencyProduction>());

/**
* @brief FileHandle support move semantic but isn't copyable
Expand Down
126 changes: 126 additions & 0 deletions cpp/src/detail/file_handle_dep.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <kvikio/detail/file_handle_dep.hpp>
#include <kvikio/file_handle.hpp>

namespace {

/**
* @brief Parse open file flags given as a string and return oflags
*
* @param flags The flags
* @param o_direct Append O_DIRECT to the open flags
* @return oflags
*
* @throw std::invalid_argument if the specified flags are not supported.
* @throw std::invalid_argument if `o_direct` is true, but `O_DIRECT` is not supported.
*/
int open_fd_parse_flags(const std::string& flags, bool o_direct)
{
int file_flags = -1;
if (flags.empty()) { throw std::invalid_argument("Unknown file open flag"); }
switch (flags[0]) {
case 'r':
file_flags = O_RDONLY;
if (flags[1] == '+') { file_flags = O_RDWR; }
break;
case 'w':
file_flags = O_WRONLY;
if (flags[1] == '+') { file_flags = O_RDWR; }
file_flags |= O_CREAT | O_TRUNC;
break;
case 'a': throw std::invalid_argument("Open flag 'a' isn't supported");
default: throw std::invalid_argument("Unknown file open flag");
}
file_flags |= O_CLOEXEC;
if (o_direct) {
#if defined(O_DIRECT)
file_flags |= O_DIRECT;
#else
throw std::invalid_argument("'o_direct' flag unsupported on this platform");
#endif
}
return file_flags;
}

} // namespace

namespace kvikio {

namespace detail {
void FileHandleDependencyBase::set_file_handle(FileHandle* file_handle)
{
_file_handle = file_handle;
}

bool FileHandleDependencyProduction::is_compat_mode_preferred() const noexcept
{
return defaults::is_compat_mode_preferred(_file_handle->_compat_mode);
}

bool FileHandleDependencyProduction::is_compat_mode_preferred_for_async() const noexcept
{
static bool is_extra_symbol_available = is_stream_api_available();
static bool is_config_path_empty = config_path().empty();
return is_compat_mode_preferred() || !is_extra_symbol_available || is_config_path_empty;
}

bool FileHandleDependencyProduction::is_compat_mode_preferred_for_async(
CompatMode requested_compat_mode)
{
if (defaults::is_compat_mode_preferred(requested_compat_mode)) { return true; }

if (!is_stream_api_available()) {
if (requested_compat_mode == CompatMode::AUTO) { return true; }
throw std::runtime_error("Missing the cuFile stream api.");
}

// When checking for availability, we also check if cuFile's config file exists. This is
// because even when the stream API is available, it doesn't work if no config file exists.
if (config_path().empty()) {
if (requested_compat_mode == CompatMode::AUTO) { return true; }
throw std::runtime_error("Missing cuFile configuration file.");
}
return false;
}

int FileHandleDependencyProduction::open_fd(const std::string& file_path,
const std::string& flags,
bool o_direct,
mode_t mode)
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg)
int fd = ::open(file_path.c_str(), open_fd_parse_flags(flags, o_direct), mode);
if (fd == -1) { throw std::system_error(errno, std::generic_category(), "Unable to open file"); }
return fd;
}

void FileHandleDependencyProduction::close_fd(int fd) { ::close(fd); }

CUfileError_t FileHandleDependencyProduction::cuFile_handle_register(CUfileHandle_t* fh,
CUfileDescr_t* descr)
{
return cuFileAPI::instance().HandleRegister(fh, descr);
}

void FileHandleDependencyProduction::cuFile_handle_deregister(CUfileHandle_t fh)
{
cuFileAPI::instance().HandleDeregister(fh);
}

} // namespace detail
} // namespace kvikio
Loading

0 comments on commit 6036a27

Please sign in to comment.