Skip to content

Commit

Permalink
kvstore/s3 Use conditional write operations.
Browse files Browse the repository at this point in the history
AWS has added conditional write support for S3. Using conditional writes improves write atomicity in tensorstore, with some caveats:

1/ Not all S3-compatible object stores support if-match, so by default tensorstore issues conditional writes only against AWS; for other endpoints they are used only when the environment variable TENSORSTORE_S3_USE_CONDITIONAL_WRITE is set.

2/ DELETE on AWS is not atomic, even when conditional writes are supported: DELETE supports if-match only for directory buckets, so at present the if-match header is not sent on DELETE requests.

Relevant API docs:
https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html
https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObject.html

Fixes: #211
PiperOrigin-RevId: 721962621
Change-Id: Ia2831aabba645686de98e4f95103a00ae0b30498
  • Loading branch information
laramiel authored and copybara-github committed Feb 1, 2025
1 parent 61a6cee commit 0ee12fe
Show file tree
Hide file tree
Showing 20 changed files with 820 additions and 422 deletions.
4 changes: 1 addition & 3 deletions tensorstore/internal/http/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ tensorstore_cc_test(
],
deps = [
":http",
"//tensorstore/util:status_testutil",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
Expand Down Expand Up @@ -258,11 +257,10 @@ tensorstore_cc_library(
deps = [
":http",
"//tensorstore/util:result",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)
8 changes: 6 additions & 2 deletions tensorstore/internal/http/http_response.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

#include <limits>
#include <optional>
#include <string>
#include <utility>

#include "absl/status/status.h"
Expand Down Expand Up @@ -168,6 +167,12 @@ absl::StatusCode HttpResponseCodeToStatusCode(const HttpResponse& response) {
// body.)
return absl::StatusCode::kOutOfRange;

case 409: // Conflict, such as a concurrent request.
return absl::StatusCode::kAborted;

case 501: // Not Implemented
return absl::StatusCode::kUnimplemented;

// UNAVAILABLE indicates a problem that can go away if the request
// is just retried without any modification. 308 return codes are intended
// for write requests that can be retried. See the documentation and the
Expand All @@ -177,7 +182,6 @@ absl::StatusCode HttpResponseCodeToStatusCode(const HttpResponse& response) {
// https://cloud.google.com/storage/docs/request-rate
case 308: // Resume Incomplete
case 408: // Request Timeout
case 409: // Conflict
case 429: // Too Many Requests
case 500: // Internal Server Error
case 502: // Bad Gateway
Expand Down
24 changes: 13 additions & 11 deletions tensorstore/internal/http/http_response_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,12 @@

#include "tensorstore/internal/http/http_response.h"

#include <set>
#include <utility>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "tensorstore/util/status_testutil.h"

namespace {

using ::tensorstore::IsOkAndHolds;
using ::tensorstore::internal_http::HttpResponse;


TEST(HttpResponseCodeToStatusTest, AllCodes) {
using ::tensorstore::internal_http::HttpResponseCodeToStatus;

Expand Down Expand Up @@ -68,12 +59,24 @@ TEST(HttpResponseCodeToStatusTest, AllCodes) {
HttpResponseCodeToStatus({code, {}, {}}).code())
<< code;
}
for (auto code : {308, 408, 409, 429, 500, 502, 503, 504}) {
for (auto code : {308, 408, 429, 500, 502, 503, 504}) {
seen.insert(code);
EXPECT_EQ(absl::StatusCode::kUnavailable,
HttpResponseCodeToStatus({code, {}, {}}).code())
<< code;
}
for (auto code : {409}) {
seen.insert(code);
EXPECT_EQ(absl::StatusCode::kAborted,
HttpResponseCodeToStatus({code, {}, {}}).code())
<< code;
}
for (auto code : {501}) {
seen.insert(code);
EXPECT_EQ(absl::StatusCode::kUnimplemented,
HttpResponseCodeToStatus({code, {}, {}}).code())
<< code;
}

for (int i = 300; i < 600; i++) {
if (seen.count(i) > 0) continue;
Expand All @@ -84,5 +87,4 @@ TEST(HttpResponseCodeToStatusTest, AllCodes) {
}
}


} // namespace
25 changes: 15 additions & 10 deletions tensorstore/internal/http/mock_http_transport.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
Expand Down Expand Up @@ -70,10 +71,8 @@ void ApplyResponseToHandler(const Result<HttpResponse>& response,
}
}

void DefaultMockHttpTransport::Reset(
absl::flat_hash_map<std::string, internal_http::HttpResponse>
url_to_response,
bool add_headers) {
void DefaultMockHttpTransport::Reset(Responses url_to_response,
bool add_headers) {
if (add_headers) {
// Add additional headers to the response.
for (auto& kv : url_to_response) {
Expand All @@ -90,15 +89,21 @@ void DefaultMockHttpTransport::IssueRequestWithHandler(
const HttpRequest& request, IssueRequestOptions options,
HttpResponseHandler* response_handler) {
std::string key = absl::StrCat(request.method, " ", request.url);
ABSL_LOG(INFO) << key;
absl::MutexLock l(&mutex_);
requests_.push_back(request);
if (auto it =
url_to_response_.find(absl::StrCat(request.method, " ", request.url));
it != url_to_response_.end()) {
return ApplyResponseToHandler(it->second, response_handler);

for (auto& kv : url_to_response_) {
if (!kv.first.empty() && kv.first == key) {
ApplyResponseToHandler(kv.second, response_handler);
kv.first.clear();
return;
}
}

ABSL_LOG(INFO) << "Returning 404 for: " << request;
return ApplyResponseToHandler(
internal_http::HttpResponse{404, absl::Cord(), {}}, response_handler);
internal_http::HttpResponse{404, absl::Cord(key), {}}, response_handler);
}

} // namespace internal_http
Expand Down
29 changes: 16 additions & 13 deletions tensorstore/internal/http/mock_http_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "tensorstore/internal/http/http_request.h"
#include "tensorstore/internal/http/http_response.h"
#include "tensorstore/internal/http/http_transport.h"
Expand All @@ -45,20 +42,27 @@ void ApplyResponseToHandler(const Result<HttpResponse>& response,

/// Mocks an HttpTransport by overriding the IssueRequest method to
/// respond with a predefined set of request-response pairs supplied
/// to the constructor
/// to the constructor.
/// The first matching pair will be returned for each call, then expired.
class DefaultMockHttpTransport : public internal_http::HttpTransport {
public:
DefaultMockHttpTransport(
absl::flat_hash_map<std::string, internal_http::HttpResponse>
url_to_response,
bool add_headers = true) {
using Responses =
std::vector<std::pair<std::string, internal_http::HttpResponse>>;

/// Construct a DefaultMockHttpTransport that returns 404 for all requests.
DefaultMockHttpTransport() = default;

explicit DefaultMockHttpTransport(Responses url_to_response) {
Reset(std::move(url_to_response), true);
}
DefaultMockHttpTransport(Responses url_to_response, bool add_headers) {
Reset(std::move(url_to_response), add_headers);
}
virtual ~DefaultMockHttpTransport() = default;

void Reset(absl::flat_hash_map<std::string, internal_http::HttpResponse>
url_to_response,
bool add_headers = true);
/// Initializes the list of request-response pairs.
/// The first matching pair will be returned for each call, then expired.
void Reset(Responses url_to_response, bool add_headers = true);

const std::vector<HttpRequest>& requests() const { return requests_; }

Expand All @@ -69,8 +73,7 @@ class DefaultMockHttpTransport : public internal_http::HttpTransport {
private:
absl::Mutex mutex_;
std::vector<HttpRequest> requests_;
absl::flat_hash_map<std::string, internal_http::HttpResponse>
url_to_response_;
Responses url_to_response_;
};

} // namespace internal_http
Expand Down
28 changes: 25 additions & 3 deletions tensorstore/kvstore/s3/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ tensorstore_cc_library(
":s3_request_builder",
":s3_resource",
":s3_uri_utils",
":use_conditional_write",
":validate",
"//tensorstore:context",
"//tensorstore/internal:data_copy_concurrency_resource",
Expand Down Expand Up @@ -57,6 +58,7 @@ tensorstore_cc_library(
"//tensorstore/util/execution:any_receiver",
"//tensorstore/util/garbage_collection",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -140,6 +142,7 @@ tensorstore_cc_test(
":s3",
":s3_metadata",
"//tensorstore/internal/http",
"//tensorstore/kvstore:generation",
"//tensorstore/util:status_testutil",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:cord",
Expand Down Expand Up @@ -274,9 +277,7 @@ tensorstore_cc_test(
":s3_endpoint",
"//tensorstore/internal/http",
"//tensorstore/internal/http:mock_http_transport",
"//tensorstore/util:future",
"//tensorstore/util:status_testutil",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:cord",
"@com_google_googletest//:gtest_main",
Expand Down Expand Up @@ -312,7 +313,6 @@ tensorstore_cc_test(
"//tensorstore/kvstore:test_util",
"//tensorstore/util:future",
"//tensorstore/util:status_testutil",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
Expand All @@ -338,6 +338,7 @@ tensorstore_cc_test(
deps = [
":s3",
":s3_request_builder",
":use_conditional_write",
"//tensorstore:context",
"//tensorstore:json_serialization_options_base",
"//tensorstore/internal:env",
Expand Down Expand Up @@ -376,3 +377,24 @@ tensorstore_cc_test(
"@com_google_googletest//:gtest_main",
],
)

tensorstore_cc_library(
name = "use_conditional_write",
srcs = ["use_conditional_write.cc"],
hdrs = ["use_conditional_write.h"],
deps = [
"//tensorstore/internal:env",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_re2//:re2",
],
)

cc_test(
name = "use_conditional_write_test",
srcs = ["use_conditional_write_test.cc"],
deps = [
":use_conditional_write",
"@com_google_googletest//:gtest_main",
],
)
2 changes: 0 additions & 2 deletions tensorstore/kvstore/s3/credentials/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ tensorstore_cc_library(
hdrs = ["test_utils.h"],
deps = [
"//tensorstore/internal/http",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
Expand Down Expand Up @@ -169,7 +168,6 @@ tensorstore_cc_test(
"//tensorstore/internal/http:mock_http_transport",
"//tensorstore/util:result",
"//tensorstore/util:status_testutil",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/time",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ class DefaultCredentialProviderTest : public ::testing::Test {
};

TEST_F(DefaultCredentialProviderTest, AnonymousCredentials) {
auto mock_transport = std::make_shared<DefaultMockHttpTransport>(
absl::flat_hash_map<std::string, HttpResponse>());
auto mock_transport = std::make_shared<DefaultMockHttpTransport>();
auto provider = std::make_unique<DefaultAwsCredentialsProvider>(
Options{{}, {}, {}, mock_transport});

Expand Down Expand Up @@ -156,7 +155,7 @@ TEST_F(DefaultCredentialProviderTest, ConfigureEC2ProviderFromOptions) {
EXPECT_EQ(credentials.expires_at, expiry - absl::Seconds(60));

/// Force failure on credential retrieval
mock_transport->Reset(absl::flat_hash_map<std::string, HttpResponse>{
mock_transport->Reset({
{"POST http://endpoint/latest/api/token",
HttpResponse{404, absl::Cord{""}}},
});
Expand All @@ -182,7 +181,7 @@ TEST_F(DefaultCredentialProviderTest, ConfigureEC2ProviderFromOptions) {
EXPECT_EQ(credentials.expires_at, expiry - absl::Seconds(60));

/// Force failure on credential retrieval
mock_transport->Reset(absl::flat_hash_map<std::string, HttpResponse>{
mock_transport->Reset({
{"POST http://endpoint/latest/api/token",
HttpResponse{404, absl::Cord{""}}},
});
Expand Down
Loading

0 comments on commit 0ee12fe

Please sign in to comment.