From 4f2eb5dce4c2a115e7bbc4733597e3841aa5710d Mon Sep 17 00:00:00 2001 From: "Jonathan M. Henson" Date: Thu, 26 Mar 2020 20:20:31 -0700 Subject: [PATCH] Update to use stream activiation refactor in aws-c-http. (#104) * Update to use stream activiation refactor in aws-c-http. * Updated default args for allocators to use the configured one. Updated STL allocator to allow PMR style allocators where it makes sense. Updated http to use the new stream activate() apis. Added test for stream that isn't activated. * Updated to latest http version, updated api contract and documentation to reflect. * Update to latest version of all crt libs. --- .github/workflows/ci.yml | 2 +- aws-common-runtime/aws-c-auth | 2 +- aws-common-runtime/aws-c-common | 2 +- aws-common-runtime/aws-c-compression | 2 +- aws-common-runtime/aws-c-http | 2 +- aws-common-runtime/aws-c-io | 2 +- include/aws/crt/StlAllocator.h | 22 ++-- include/aws/crt/auth/Credentials.h | 22 ++-- include/aws/crt/auth/Sigv4Signing.h | 4 +- include/aws/crt/crypto/Hash.h | 4 +- include/aws/crt/http/HttpConnection.h | 30 ++++- include/aws/crt/http/HttpConnectionManager.h | 4 +- include/aws/crt/http/HttpRequestResponse.h | 4 +- include/aws/crt/io/Bootstrap.h | 2 +- include/aws/crt/io/EventLoopGroup.h | 4 +- include/aws/crt/io/HostResolver.h | 2 +- include/aws/crt/io/Stream.h | 4 +- include/aws/crt/io/TlsOptions.h | 13 +- include/aws/crt/io/Uri.h | 4 +- include/aws/crt/mqtt/MqttClient.h | 2 +- include/aws/iot/MqttClient.h | 14 +-- source/crypto/HMAC.cpp | 5 +- source/http/HttpConnection.cpp | 71 ++++++----- tests/CMakeLists.txt | 1 + tests/HttpClientTest.cpp | 119 ++++++++++++++++++- 25 files changed, 246 insertions(+), 97 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b04deb853..b14a8b35c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: - '!master' env: - BUILDER_VERSION: v0.5.7 + BUILDER_VERSION: v0.5.9 BUILDER_HOST: https://d19elf31gohf1l.cloudfront.net PACKAGE_NAME: aws-crt-cpp LINUX_BASE_IMAGE: ubuntu-16-x64 diff --git a/aws-common-runtime/aws-c-auth b/aws-common-runtime/aws-c-auth index 543e3764d..479187b39 160000 --- a/aws-common-runtime/aws-c-auth +++ b/aws-common-runtime/aws-c-auth @@ -1 +1 @@ -Subproject commit 543e3764d7fa4f1228ed630d1aa40e4a0de396d6 +Subproject commit 479187b39ddc7ae8b67d782b444ceb64fc2ff1c8 diff --git a/aws-common-runtime/aws-c-common b/aws-common-runtime/aws-c-common index d023c9cb1..6f0787f0c 160000 --- a/aws-common-runtime/aws-c-common +++ b/aws-common-runtime/aws-c-common @@ -1 +1 @@ -Subproject commit d023c9cb10e22bace150550a7357eab36164af52 +Subproject commit 6f0787f0c567326a5313188af7cb3b4dbfef1a24 diff --git a/aws-common-runtime/aws-c-compression b/aws-common-runtime/aws-c-compression index c7e477bf6..4b32a9b8b 160000 --- a/aws-common-runtime/aws-c-compression +++ b/aws-common-runtime/aws-c-compression @@ -1 +1 @@ -Subproject commit c7e477bf6ab7df17cdad223300541fe3aa978f35 +Subproject commit 4b32a9b8bdf07698fd080bade472a5e135abb70a diff --git a/aws-common-runtime/aws-c-http b/aws-common-runtime/aws-c-http index e5386127c..f5096cab2 160000 --- a/aws-common-runtime/aws-c-http +++ b/aws-common-runtime/aws-c-http @@ -1 +1 @@ -Subproject commit e5386127cd43b5aea476500b9a38d7e9b72b82f7 +Subproject commit f5096cab294d201bc4a217ed5679f912baa14684 diff --git a/aws-common-runtime/aws-c-io b/aws-common-runtime/aws-c-io index 3b09cdcee..c91d18348 160000 --- a/aws-common-runtime/aws-c-io +++ b/aws-common-runtime/aws-c-io @@ -1 +1 @@ -Subproject commit 3b09cdceed9db43b2ac8642083417051de8d9041 +Subproject commit c91d18348e01345eec969b67697401c6ca67d7b9 diff --git a/include/aws/crt/StlAllocator.h b/include/aws/crt/StlAllocator.h index 0f1f1ba00..9b2e56780 100644 --- a/include/aws/crt/StlAllocator.h +++ b/include/aws/crt/StlAllocator.h @@ -30,10 +30,16 @@ namespace Aws public: using Base = std::allocator; - StlAllocator() noexcept : Base() {} - StlAllocator(const StlAllocator &a) noexcept : Base(a) {} + StlAllocator() noexcept : Base() { m_allocator = g_allocator; } - template StlAllocator(const StlAllocator &a) noexcept : Base(a) {} + StlAllocator(Allocator *allocator) noexcept : Base() { m_allocator = allocator; } + + StlAllocator(const StlAllocator &a) noexcept : Base(a) { m_allocator = a.m_allocator; } + + template StlAllocator(const StlAllocator &a) noexcept : Base(a) + { + m_allocator = a.m_allocator; + } ~StlAllocator() {} @@ -47,15 +53,17 @@ namespace Aws typename Base::pointer allocate(size_type n, const void *hint = nullptr) { (void)hint; - AWS_ASSERT(g_allocator); - return reinterpret_cast(aws_mem_acquire(g_allocator, n * sizeof(T))); + AWS_ASSERT(m_allocator); + return reinterpret_cast(aws_mem_acquire(m_allocator, n * sizeof(T))); } void deallocate(typename Base::pointer p, size_type) { - AWS_ASSERT(g_allocator); - aws_mem_release(g_allocator, p); + AWS_ASSERT(m_allocator); + aws_mem_release(m_allocator, p); } + + Allocator *m_allocator; }; } // namespace Crt } // namespace Aws diff --git a/include/aws/crt/auth/Credentials.h b/include/aws/crt/auth/Credentials.h index b637d18f7..bf63266fd 100644 --- a/include/aws/crt/auth/Credentials.h +++ b/include/aws/crt/auth/Credentials.h @@ -40,12 +40,12 @@ namespace Aws class AWS_CRT_CPP_API Credentials { public: - Credentials(aws_credentials *credentials, Allocator *allocator = DefaultAllocator()) noexcept; + Credentials(aws_credentials *credentials, Allocator *allocator = g_allocator) noexcept; Credentials( ByteCursor access_key_id, ByteCursor secret_access_key, ByteCursor session_token, - Allocator *allocator = DefaultAllocator()) noexcept; + Allocator *allocator = g_allocator) noexcept; ~Credentials(); @@ -241,9 +241,7 @@ namespace Aws class AWS_CRT_CPP_API CredentialsProvider : public ICredentialsProvider { public: - CredentialsProvider( - aws_credentials_provider *provider, - Allocator *allocator = DefaultAllocator()) noexcept; + CredentialsProvider(aws_credentials_provider *provider, Allocator *allocator = g_allocator) noexcept; virtual ~CredentialsProvider(); @@ -278,27 +276,27 @@ namespace Aws */ static std::shared_ptr CreateCredentialsProviderStatic( const CredentialsProviderStaticConfig &config, - Allocator *allocator = DefaultAllocator()); + Allocator *allocator = g_allocator); /** * Creates a provider that returns credentials sourced from environment variables */ static std::shared_ptr CreateCredentialsProviderEnvironment( - Allocator *allocator = DefaultAllocator()); + Allocator *allocator = g_allocator); /** * Creates a provider that returns credentials sourced from config files */ static std::shared_ptr CreateCredentialsProviderProfile( const CredentialsProviderProfileConfig &config, - Allocator *allocator = DefaultAllocator()); + Allocator *allocator = g_allocator); /** * Creates a provider that returns credentials sourced from Ec2 instance metadata service */ static std::shared_ptr CreateCredentialsProviderImds( const CredentialsProviderImdsConfig &config, - Allocator *allocator = DefaultAllocator()); + Allocator *allocator = g_allocator); /** * Creates a provider that sources credentials by querying a series of providers and @@ -306,7 +304,7 @@ namespace Aws */ static std::shared_ptr CreateCredentialsProviderChain( const CredentialsProviderChainConfig &config, - Allocator *allocator = DefaultAllocator()); + Allocator *allocator = g_allocator); /* * Creates a provider that puts a simple time-based cache in front of its queries @@ -314,7 +312,7 @@ namespace Aws */ static std::shared_ptr CreateCredentialsProviderCached( const CredentialsProviderCachedConfig &config, - Allocator *allocator = DefaultAllocator()); + Allocator *allocator = g_allocator); /** * Creates the SDK-standard default credentials provider which is a cache-fronted chain of: @@ -324,7 +322,7 @@ namespace Aws */ static std::shared_ptr CreateCredentialsProviderChainDefault( const CredentialsProviderChainDefaultConfig &config, - Allocator *allocator = DefaultAllocator()); + Allocator *allocator = g_allocator); private: static void s_onCredentialsResolved(aws_credentials *credentials, void *user_data); diff --git a/include/aws/crt/auth/Sigv4Signing.h b/include/aws/crt/auth/Sigv4Signing.h index b54d4e7ff..a203f28a5 100644 --- a/include/aws/crt/auth/Sigv4Signing.h +++ b/include/aws/crt/auth/Sigv4Signing.h @@ -56,7 +56,7 @@ namespace Aws class AWS_CRT_CPP_API AwsSigningConfig : public ISigningConfig { public: - AwsSigningConfig(Allocator *allocator = DefaultAllocator()); + AwsSigningConfig(Allocator *allocator = g_allocator); virtual ~AwsSigningConfig(); virtual SigningConfigType GetType() const noexcept override { return SigningConfigType::Aws; } @@ -178,7 +178,7 @@ namespace Aws class AWS_CRT_CPP_API Sigv4HttpRequestSigner : public IHttpRequestSigner { public: - Sigv4HttpRequestSigner(Allocator *allocator = DefaultAllocator()); + Sigv4HttpRequestSigner(Allocator *allocator = g_allocator); virtual ~Sigv4HttpRequestSigner() = default; bool IsValid() const override { return true; } diff --git a/include/aws/crt/crypto/Hash.h b/include/aws/crt/crypto/Hash.h index 86c9d7500..cf97a6f0d 100644 --- a/include/aws/crt/crypto/Hash.h +++ b/include/aws/crt/crypto/Hash.h @@ -94,12 +94,12 @@ namespace Aws /** * Creates an instance of a Streaming SHA256 Hash. */ - static Hash CreateSHA256(Allocator *allocator = DefaultAllocator()) noexcept; + static Hash CreateSHA256(Allocator *allocator = g_allocator) noexcept; /** * Creates an instance of a Streaming MD5 Hash. */ - static Hash CreateMD5(Allocator *allocator = DefaultAllocator()) noexcept; + static Hash CreateMD5(Allocator *allocator = g_allocator) noexcept; /** * Updates the running hash object with data in toHash. Returns true on success. Call diff --git a/include/aws/crt/http/HttpConnection.h b/include/aws/crt/http/HttpConnection.h index eb947c14d..d3f87e255 100644 --- a/include/aws/crt/http/HttpConnection.h +++ b/include/aws/crt/http/HttpConnection.h @@ -127,7 +127,7 @@ namespace Aws * Represents a single http message exchange (request/response) or in H2, it can also represent * a PUSH_PROMISE followed by the accompanying Response. */ - class AWS_CRT_CPP_API HttpStream + class AWS_CRT_CPP_API HttpStream : public std::enable_shared_from_this { public: virtual ~HttpStream(); @@ -184,10 +184,17 @@ namespace Aws friend class HttpClientConnection; }; + struct ClientStreamCallbackData + { + ClientStreamCallbackData() : allocator(nullptr), stream(nullptr) {} + Allocator *allocator; + std::shared_ptr stream; + }; + class AWS_CRT_CPP_API HttpClientStream final : public HttpStream { public: - ~HttpClientStream() = default; + ~HttpClientStream(); HttpClientStream(const HttpClientStream &) = delete; HttpClientStream(HttpClientStream &&) = delete; HttpClientStream &operator=(const HttpClientStream &) = delete; @@ -199,9 +206,17 @@ namespace Aws */ virtual int GetResponseStatusCode() const noexcept override; + /** + * Activates the request's outgoing stream processing. + * + * Returns true on success, false otherwise. + */ + bool Activate() noexcept; + private: HttpClientStream(const std::shared_ptr &connection) noexcept; + ClientStreamCallbackData m_callbackData; friend class HttpClientConnection; }; @@ -333,6 +348,15 @@ namespace Aws * Optional. */ Optional ProxyOptions; + + /** + * If set to true, then the TCP read back pressure mechanism will be enabled. You should + * only use this if you're allowing http response body data to escape the callbacks. E.g. you're + * putting the data into a queue for another thread to process and need to make sure the memory + * usage is bounded. If this is enabled, you must call HttpStream::UpdateWindow() for every + * byte read from the OnIncomingBody callback. + */ + bool ManualWindowManagement; }; /** @@ -356,6 +380,8 @@ namespace Aws * not be freed until the stream is completed. * * Returns an instance of HttpStream upon success and nullptr on failure. + * + * You must call HttpClientStream::Activate() to begin outgoing processing of the stream. */ std::shared_ptr NewClientStream(const HttpRequestOptions &requestOptions) noexcept; diff --git a/include/aws/crt/http/HttpConnectionManager.h b/include/aws/crt/http/HttpConnectionManager.h index 918c7115a..318cd14c8 100644 --- a/include/aws/crt/http/HttpConnectionManager.h +++ b/include/aws/crt/http/HttpConnectionManager.h @@ -99,12 +99,12 @@ namespace Aws */ static std::shared_ptr NewClientConnectionManager( const HttpClientConnectionManagerOptions &connectionManagerOptions, - Allocator *allocator = DefaultAllocator()) noexcept; + Allocator *allocator = g_allocator) noexcept; private: HttpClientConnectionManager( const HttpClientConnectionManagerOptions &options, - Allocator *allocator = DefaultAllocator()) noexcept; + Allocator *allocator = g_allocator) noexcept; Allocator *m_allocator; diff --git a/include/aws/crt/http/HttpRequestResponse.h b/include/aws/crt/http/HttpRequestResponse.h index f99625c77..b8fffe1ea 100644 --- a/include/aws/crt/http/HttpRequestResponse.h +++ b/include/aws/crt/http/HttpRequestResponse.h @@ -85,7 +85,7 @@ namespace Aws friend class Mqtt::MqttConnection; public: - HttpRequest(Allocator *allocator = DefaultAllocator()); + HttpRequest(Allocator *allocator = g_allocator); /** * Gets the value of the Http method associated with this request @@ -117,7 +117,7 @@ namespace Aws class AWS_CRT_CPP_API HttpResponse : public HttpMessage { public: - HttpResponse(Allocator *allocator = DefaultAllocator()); + HttpResponse(Allocator *allocator = g_allocator); /** * Gets the integral Http response code associated with this response diff --git a/include/aws/crt/io/Bootstrap.h b/include/aws/crt/io/Bootstrap.h index 55c96ed9f..f157cdbdd 100644 --- a/include/aws/crt/io/Bootstrap.h +++ b/include/aws/crt/io/Bootstrap.h @@ -34,7 +34,7 @@ namespace Aws ClientBootstrap( EventLoopGroup &elGroup, HostResolver &resolver, - Allocator *allocator = DefaultAllocator()) noexcept; + Allocator *allocator = g_allocator) noexcept; ~ClientBootstrap(); ClientBootstrap(const ClientBootstrap &) = delete; ClientBootstrap &operator=(const ClientBootstrap &) = delete; diff --git a/include/aws/crt/io/EventLoopGroup.h b/include/aws/crt/io/EventLoopGroup.h index 5e7b469f2..0e1f03bdb 100644 --- a/include/aws/crt/io/EventLoopGroup.h +++ b/include/aws/crt/io/EventLoopGroup.h @@ -38,8 +38,8 @@ namespace Aws class AWS_CRT_CPP_API EventLoopGroup final { public: - EventLoopGroup(Allocator *allocator = DefaultAllocator()) noexcept; - EventLoopGroup(uint16_t threadCount, Allocator *allocator = DefaultAllocator()) noexcept; + EventLoopGroup(Allocator *allocator = g_allocator) noexcept; + EventLoopGroup(uint16_t threadCount, Allocator *allocator = g_allocator) noexcept; ~EventLoopGroup(); EventLoopGroup(const EventLoopGroup &) = delete; EventLoopGroup(EventLoopGroup &&) noexcept; diff --git a/include/aws/crt/io/HostResolver.h b/include/aws/crt/io/HostResolver.h index 39660c705..be5f2b540 100644 --- a/include/aws/crt/io/HostResolver.h +++ b/include/aws/crt/io/HostResolver.h @@ -58,7 +58,7 @@ namespace Aws EventLoopGroup &elGroup, size_t maxHosts, size_t maxTTL, - Allocator *allocator = DefaultAllocator()) noexcept; + Allocator *allocator = g_allocator) noexcept; ~DefaultHostResolver(); DefaultHostResolver(const DefaultHostResolver &) = delete; DefaultHostResolver &operator=(const DefaultHostResolver &) = delete; diff --git a/include/aws/crt/io/Stream.h b/include/aws/crt/io/Stream.h index 752f1dc2a..3b611ad7f 100644 --- a/include/aws/crt/io/Stream.h +++ b/include/aws/crt/io/Stream.h @@ -57,7 +57,7 @@ namespace Aws Allocator *m_allocator; aws_input_stream m_underlying_stream; - InputStream(Aws::Crt::Allocator *allocator = DefaultAllocator()); + InputStream(Aws::Crt::Allocator *allocator = g_allocator); /*** * Read up-to buffer::capacity - buffer::len into buffer::buffer @@ -109,7 +109,7 @@ namespace Aws public: StdIOStreamInputStream( std::shared_ptr stream, - Aws::Crt::Allocator *allocator = DefaultAllocator()) noexcept; + Aws::Crt::Allocator *allocator = g_allocator) noexcept; bool IsValid() const noexcept override; diff --git a/include/aws/crt/io/TlsOptions.h b/include/aws/crt/io/TlsOptions.h index 402a87fef..030e86f86 100644 --- a/include/aws/crt/io/TlsOptions.h +++ b/include/aws/crt/io/TlsOptions.h @@ -50,7 +50,7 @@ namespace Aws * Initializes TlsContextOptions with secure by default options, with * no client certificates. */ - static TlsContextOptions InitDefaultClient(Allocator *allocator = DefaultAllocator()) noexcept; + static TlsContextOptions InitDefaultClient(Allocator *allocator = g_allocator) noexcept; /** * Initializes TlsContextOptions with secure by default options, with * client certificate and private key. These are paths to a file on disk. These files @@ -59,7 +59,7 @@ namespace Aws static TlsContextOptions InitClientWithMtls( const char *cert_path, const char *pkey_path, - Allocator *allocator = DefaultAllocator()) noexcept; + Allocator *allocator = g_allocator) noexcept; /** * Initializes TlsContextOptions with secure by default options, with @@ -69,7 +69,7 @@ namespace Aws static TlsContextOptions InitClientWithMtls( const ByteCursor &cert, const ByteCursor &pkey, - Allocator *allocator = DefaultAllocator()) noexcept; + Allocator *allocator = g_allocator) noexcept; #ifdef __APPLE__ /** @@ -81,7 +81,7 @@ namespace Aws static TlsContextOptions InitClientWithMtlsPkcs12( const char *pkcs12_path, const char *pkcs12_pwd, - Allocator *allocator = DefaultAllocator()) noexcept; + Allocator *allocator = g_allocator) noexcept; #endif /** @@ -172,10 +172,7 @@ namespace Aws { public: TlsContext() noexcept; - TlsContext( - TlsContextOptions &options, - TlsMode mode, - Allocator *allocator = DefaultAllocator()) noexcept; + TlsContext(TlsContextOptions &options, TlsMode mode, Allocator *allocator = g_allocator) noexcept; ~TlsContext() = default; TlsContext(const TlsContext &) noexcept = default; TlsContext &operator=(const TlsContext &) noexcept = default; diff --git a/include/aws/crt/io/Uri.h b/include/aws/crt/io/Uri.h index 09fc796cd..245b79b20 100644 --- a/include/aws/crt/io/Uri.h +++ b/include/aws/crt/io/Uri.h @@ -35,12 +35,12 @@ namespace Aws * Parses `cursor` as a URI. Upon failure the bool() operator will return false and LastError() * will contain the errorCode. */ - Uri(const ByteCursor &cursor, Allocator *allocator = DefaultAllocator()) noexcept; + Uri(const ByteCursor &cursor, Allocator *allocator = g_allocator) noexcept; /** * builds a URI from `builderOptions`. Upon failure the bool() operator will return false and * LastError() will contain the errorCode. */ - Uri(aws_uri_builder_options &builderOptions, Allocator *allocator = DefaultAllocator()) noexcept; + Uri(aws_uri_builder_options &builderOptions, Allocator *allocator = g_allocator) noexcept; Uri(const Uri &); Uri &operator=(const Uri &); Uri(Uri &&uri) noexcept; diff --git a/include/aws/crt/mqtt/MqttClient.h b/include/aws/crt/mqtt/MqttClient.h index f24259093..c2c6f5db6 100644 --- a/include/aws/crt/mqtt/MqttClient.h +++ b/include/aws/crt/mqtt/MqttClient.h @@ -300,7 +300,7 @@ namespace Aws /** * Initialize an MqttClient using bootstrap and allocator */ - MqttClient(Io::ClientBootstrap &bootstrap, Allocator *allocator = DefaultAllocator()) noexcept; + MqttClient(Io::ClientBootstrap &bootstrap, Allocator *allocator = g_allocator) noexcept; ~MqttClient(); MqttClient(const MqttClient &) = delete; diff --git a/include/aws/iot/MqttClient.h b/include/aws/iot/MqttClient.h index 86d59150a..f6bdef20f 100644 --- a/include/aws/iot/MqttClient.h +++ b/include/aws/iot/MqttClient.h @@ -86,7 +86,7 @@ namespace Aws WebsocketConfig( const Crt::String &signingRegion, Crt::Io::ClientBootstrap *bootstrap, - Crt::Allocator *allocator = Crt::DefaultAllocator()) noexcept; + Crt::Allocator *allocator = Crt::g_allocator) noexcept; /** * Create a websocket configuration for use with a custom credentials provider. Signing region will be use @@ -95,7 +95,7 @@ namespace Aws WebsocketConfig( const Crt::String &signingRegion, const std::shared_ptr &credentialsProvider, - Crt::Allocator *allocator = Crt::DefaultAllocator()) noexcept; + Crt::Allocator *allocator = Crt::g_allocator) noexcept; /** * Create a websocket configuration for use with a custom credentials provider, and a custom signer. @@ -140,7 +140,7 @@ namespace Aws MqttClientConnectionConfigBuilder( const char *certPath, const char *pkeyPath, - Crt::Allocator *allocator = Crt::DefaultAllocator()) noexcept; + Crt::Allocator *allocator = Crt::g_allocator) noexcept; /** * Sets the builder up for MTLS using cert and pkey. These are in-memory buffers and must be in the PEM @@ -149,14 +149,14 @@ namespace Aws MqttClientConnectionConfigBuilder( const Crt::ByteCursor &cert, const Crt::ByteCursor &pkey, - Crt::Allocator *allocator = Crt::DefaultAllocator()) noexcept; + Crt::Allocator *allocator = Crt::g_allocator) noexcept; /** * Sets the builder up for Websocket connection. */ MqttClientConnectionConfigBuilder( const WebsocketConfig &config, - Crt::Allocator *allocator = Crt::DefaultAllocator()) noexcept; + Crt::Allocator *allocator = Crt::g_allocator) noexcept; /** * Sets endpoint to connect to. @@ -234,9 +234,7 @@ namespace Aws class AWS_CRT_CPP_API MqttClient final { public: - MqttClient( - Crt::Io::ClientBootstrap &bootstrap, - Crt::Allocator *allocator = Crt::DefaultAllocator()) noexcept; + MqttClient(Crt::Io::ClientBootstrap &bootstrap, Crt::Allocator *allocator = Crt::g_allocator) noexcept; std::shared_ptr NewConnection(const MqttClientConnectionConfig &config) noexcept; diff --git a/source/crypto/HMAC.cpp b/source/crypto/HMAC.cpp index 978c1ff74..215cdbd69 100644 --- a/source/crypto/HMAC.cpp +++ b/source/crypto/HMAC.cpp @@ -38,8 +38,7 @@ namespace Aws ByteBuf &output, size_t truncateTo) noexcept { - return aws_sha256_hmac_compute(DefaultAllocator(), &secret, &input, &output, truncateTo) == - AWS_OP_SUCCESS; + return aws_sha256_hmac_compute(g_allocator, &secret, &input, &output, truncateTo) == AWS_OP_SUCCESS; } HMAC::HMAC(aws_hmac *hmac) noexcept : m_hmac(hmac), m_good(false), m_lastError(0) @@ -86,7 +85,7 @@ namespace Aws HMAC HMAC::CreateSHA256HMAC(const ByteCursor &secret) noexcept { - return HMAC(aws_sha256_hmac_new(DefaultAllocator(), &secret)); + return HMAC(aws_sha256_hmac_new(g_allocator, &secret)); } bool HMAC::Update(const ByteCursor &toHMAC) noexcept diff --git a/source/http/HttpConnection.cpp b/source/http/HttpConnection.cpp index 3d5bda240..118bcc13e 100644 --- a/source/http/HttpConnection.cpp +++ b/source/http/HttpConnection.cpp @@ -130,6 +130,7 @@ namespace Aws options.socket_options = &connectionOptions.SocketOptions.GetImpl(); options.on_setup = HttpClientConnection::s_onClientConnectionSetup; options.on_shutdown = HttpClientConnection::s_onClientConnectionShutdown; + options.manual_window_management = connectionOptions.ManualWindowManagement; if (aws_http_client_connect(&options)) { @@ -145,12 +146,6 @@ namespace Aws { } - struct ClientStreamCallbackData - { - Allocator *allocator; - std::shared_ptr stream; - }; - std::shared_ptr HttpClientConnection::NewClientStream( const HttpRequestOptions &requestOptions) noexcept { @@ -168,41 +163,38 @@ namespace Aws /* Do the same ref counting trick we did with HttpClientConnection. We need to maintain a reference * internally (regardless of what the user does), until the Stream shuts down. */ - auto *toSeat = static_cast(aws_mem_acquire(m_allocator, sizeof(HttpStream))); + auto *toSeat = static_cast(aws_mem_acquire(m_allocator, sizeof(HttpClientStream))); if (toSeat) { - auto *callbackData = New(m_allocator); - if (!callbackData) - { - aws_mem_release(m_allocator, toSeat); - return nullptr; - } - toSeat = new (toSeat) HttpClientStream(this->shared_from_this()); Allocator *captureAllocator = m_allocator; - callbackData->stream = std::shared_ptr( - toSeat, [captureAllocator](HttpStream *stream) { Delete(stream, captureAllocator); }); - - toSeat->m_onIncomingBody = requestOptions.onIncomingBody; - toSeat->m_onIncomingHeaders = requestOptions.onIncomingHeaders; - toSeat->m_onIncomingHeadersBlockDone = requestOptions.onIncomingHeadersBlockDone; - toSeat->m_onStreamComplete = requestOptions.onStreamComplete; - - callbackData->allocator = m_allocator; - options.user_data = callbackData; - toSeat->m_stream = aws_http_connection_make_request(m_connection, &options); - - if (!toSeat->m_stream) + std::shared_ptr stream( + toSeat, + [captureAllocator](HttpStream *stream) { Delete(stream, captureAllocator); }, + StlAllocator(captureAllocator)); + + stream->m_onIncomingBody = requestOptions.onIncomingBody; + stream->m_onIncomingHeaders = requestOptions.onIncomingHeaders; + stream->m_onIncomingHeadersBlockDone = requestOptions.onIncomingHeadersBlockDone; + stream->m_onStreamComplete = requestOptions.onStreamComplete; + stream->m_callbackData.allocator = m_allocator; + + // we purposefully do not set m_callbackData::stream because we don't want the reference count + // incremented until the request is kicked off via HttpClientStream::Activate(). Activate() + // increments the ref count. + options.user_data = &stream->m_callbackData; + stream->m_stream = aws_http_connection_make_request(m_connection, &options); + + if (!stream->m_stream) { - callbackData->stream = nullptr; - Delete(callbackData, m_allocator); + stream = nullptr; m_lastError = aws_last_error(); return nullptr; } - return callbackData->stream; + return stream; } m_lastError = aws_last_error(); @@ -260,9 +252,7 @@ namespace Aws { auto callbackData = static_cast(userData); callbackData->stream->m_onStreamComplete(*callbackData->stream, errorCode); - callbackData->stream = nullptr; - Delete(callbackData, callbackData->allocator); } HttpStream::HttpStream(const std::shared_ptr &connection) noexcept @@ -290,6 +280,8 @@ namespace Aws { } + HttpClientStream::~HttpClientStream() {} + int HttpClientStream::GetResponseStatusCode() const noexcept { int status = 0; @@ -301,6 +293,18 @@ namespace Aws return -1; } + bool HttpClientStream::Activate() noexcept + { + m_callbackData.stream = shared_from_this(); + if (aws_http_stream_activate(m_stream)) + { + m_callbackData.stream = nullptr; + return false; + } + + return true; + } + void HttpStream::UpdateWindow(std::size_t incrementSize) noexcept { aws_http_stream_update_window(m_stream, incrementSize); @@ -314,7 +318,8 @@ namespace Aws HttpClientConnectionOptions::HttpClientConnectionOptions() : Bootstrap(nullptr), InitialWindowSize(SIZE_MAX), OnConnectionSetupCallback(), - OnConnectionShutdownCallback(), HostName(), Port(0), SocketOptions(), TlsOptions(), ProxyOptions() + OnConnectionShutdownCallback(), HostName(), Port(0), SocketOptions(), TlsOptions(), ProxyOptions(), + ManualWindowManagement(false) { } } // namespace Http diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c6ee051f9..9445c15fd 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -21,6 +21,7 @@ add_test_case(SHA256ResourceSafety) add_test_case(MD5ResourceSafety) add_test_case(SHA256HMACResourceSafety) add_net_test_case(HttpDownloadNoBackPressure) +add_net_test_case(HttpStreamUnActivated) add_net_test_case(IotPublishSubscribe) add_net_test_case(HttpClientConnectionManagerResourceSafety) add_net_test_case(HttpClientConnectionWithPendingAcquisitions) diff --git a/tests/HttpClientTest.cpp b/tests/HttpClientTest.cpp index 40380cb0e..df5d5a400 100644 --- a/tests/HttpClientTest.cpp +++ b/tests/HttpClientTest.cpp @@ -189,7 +189,9 @@ static int s_TestHttpDownloadNoBackPressure(struct aws_allocator *allocator, voi host_header.value = uri.GetHostName(); request.AddHeader(host_header); - connection->NewClientStream(requestOptions); + auto stream = connection->NewClientStream(requestOptions); + ASSERT_TRUE(stream->Activate()); + semaphore.wait(semaphoreULock, [&]() { return streamCompleted; }); ASSERT_INT_EQUALS(200, responseCode); @@ -202,3 +204,118 @@ static int s_TestHttpDownloadNoBackPressure(struct aws_allocator *allocator, voi } AWS_TEST_CASE(HttpDownloadNoBackPressure, s_TestHttpDownloadNoBackPressure) + +static int s_TestHttpStreamUnActivated(struct aws_allocator *allocator, void *ctx) +{ + (void)ctx; + Aws::Crt::ApiHandle apiHandle(allocator); + Aws::Crt::Io::TlsContextOptions tlsCtxOptions = Aws::Crt::Io::TlsContextOptions::InitDefaultClient(); + Aws::Crt::Io::TlsContext tlsContext(tlsCtxOptions, Aws::Crt::Io::TlsMode::CLIENT, allocator); + ASSERT_TRUE(tlsContext); + + Aws::Crt::Io::TlsConnectionOptions tlsConnectionOptions = tlsContext.NewConnectionOptions(); + + ByteCursor cursor = ByteCursorFromCString("https://aws-crt-test-stuff.s3.amazonaws.com/http_test_doc.txt"); + Io::Uri uri(cursor, allocator); + + auto hostName = uri.GetHostName(); + tlsConnectionOptions.SetServerName(hostName); + + Aws::Crt::Io::SocketOptions socketOptions; + socketOptions.SetConnectTimeoutMs(1000); + + Aws::Crt::Io::EventLoopGroup eventLoopGroup(0, allocator); + ASSERT_TRUE(eventLoopGroup); + + Aws::Crt::Io::DefaultHostResolver defaultHostResolver(eventLoopGroup, 8, 30, allocator); + ASSERT_TRUE(defaultHostResolver); + + Aws::Crt::Io::ClientBootstrap clientBootstrap(eventLoopGroup, defaultHostResolver, allocator); + ASSERT_TRUE(clientBootstrap); + + std::shared_ptr connection(nullptr); + bool errorOccured = true; + bool connectionShutdown = false; + + std::condition_variable semaphore; + std::mutex semaphoreLock; + + auto onConnectionSetup = [&](const std::shared_ptr &newConnection, int errorCode) { + std::lock_guard lockGuard(semaphoreLock); + + if (!errorCode) + { + connection = newConnection; + errorOccured = false; + } + else + { + connectionShutdown = true; + } + + semaphore.notify_one(); + }; + + auto onConnectionShutdown = [&](Http::HttpClientConnection &newConnection, int errorCode) { + std::lock_guard lockGuard(semaphoreLock); + + connectionShutdown = true; + if (errorCode) + { + errorOccured = true; + } + + semaphore.notify_one(); + }; + + Http::HttpClientConnectionOptions httpClientConnectionOptions; + httpClientConnectionOptions.Bootstrap = &clientBootstrap; + httpClientConnectionOptions.OnConnectionSetupCallback = onConnectionSetup; + httpClientConnectionOptions.OnConnectionShutdownCallback = onConnectionShutdown; + httpClientConnectionOptions.SocketOptions = socketOptions; + httpClientConnectionOptions.TlsOptions = tlsConnectionOptions; + httpClientConnectionOptions.HostName = String((const char *)hostName.ptr, hostName.len); + httpClientConnectionOptions.Port = 443; + + std::unique_lock semaphoreULock(semaphoreLock); + ASSERT_TRUE(Http::HttpClientConnection::CreateConnection(httpClientConnectionOptions, allocator)); + semaphore.wait(semaphoreULock, [&]() { return connection || connectionShutdown; }); + + ASSERT_FALSE(errorOccured); + ASSERT_FALSE(connectionShutdown); + ASSERT_TRUE(connection); + + Http::HttpRequest request; + Http::HttpRequestOptions requestOptions; + requestOptions.request = &request; + + requestOptions.onStreamComplete = [&](Http::HttpStream &, int) { + // do nothing. + }; + requestOptions.onIncomingHeadersBlockDone = nullptr; + requestOptions.onIncomingHeaders = + [&](Http::HttpStream &, enum aws_http_header_block, const Http::HttpHeader *, std::size_t) { + // do nothing + }; + requestOptions.onIncomingBody = [&](Http::HttpStream &, const ByteCursor &) { + // do nothing + }; + + request.SetMethod(ByteCursorFromCString("GET")); + request.SetPath(uri.GetPathAndQuery()); + + Http::HttpHeader host_header; + host_header.name = ByteCursorFromCString("host"); + host_header.value = uri.GetHostName(); + request.AddHeader(host_header); + + // don't activate it and let it go out of scope. + auto stream = connection->NewClientStream(requestOptions); + stream = nullptr; + connection->Close(); + semaphore.wait(semaphoreULock, [&]() { return connectionShutdown; }); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(HttpStreamUnActivated, s_TestHttpStreamUnActivated)