From 7cd634f614caa152419244d9c21223aa673ef6ad Mon Sep 17 00:00:00 2001 From: Joe Darby Date: Wed, 27 Dec 2023 12:16:08 +0000 Subject: [PATCH] Upgrade to cats effect 3 (#853) * Upgrade to cats effect 3 * Minor adjustments * Fix cross compilation * Fix it tests * Fix it test assertion * Bump http4s version --- build.sbt | 16 ++- .../src/it/scala/auth/AwsSignerItSpec.scala | 89 +++++------- .../auth/src/main/scala/auth/AwsSigner.scala | 52 +++---- .../scala/auth/AwsSignV4TestSuiteSpec.scala | 19 +-- .../src/test/scala/auth/AwsSignerSpec.scala | 135 +++++++++--------- .../src/it/scala/common/IntegrationSpec.scala | 4 +- .../src/main/scala/common/HttpCodecs.scala | 2 +- .../src/main/scala/common/headers.scala | 103 +++++++------ .../src/test/scala/common/IOFutures.scala | 43 ------ .../src/test/scala/common/UnitSpec.scala | 8 +- modules/s3/src/it/scala/s3/S3Spec.scala | 89 +++++------- modules/s3/src/main/scala/s3/S3.scala | 65 ++++----- modules/s3/src/main/scala/s3/headers.scala | 26 ++-- modules/s3/src/main/scala/s3/model.scala | 8 +- .../test/scala/s3/utils/S3UriParserSpec.scala | 25 ++-- project/plugins.sbt | 1 + 16 files changed, 304 insertions(+), 381 deletions(-) delete mode 100644 modules/common/src/test/scala/common/IOFutures.scala diff --git a/build.sbt b/build.sbt index aed54ae7..6f39e85b 100644 --- a/build.sbt +++ b/build.sbt @@ -2,16 +2,17 @@ import sbtrelease.ExtraReleaseCommands import sbtrelease.ReleaseStateTransformations._ import sbtrelease.tagsonly.TagsOnly._ -lazy val fs2Version = "2.5.11" -lazy val catsEffectVersion = "2.5.1" +lazy val fs2Version = "3.9.3" +lazy val catsEffectVersion = "3.5.2" lazy val scalatestVersion = "3.2.0" lazy val awsSdkVersion = "2.21.43" lazy val scalacheckVersion = "1.17.0" lazy val scalatestScalacheckVersion = "3.1.1.1" lazy val slf4jVersion = "1.7.32" lazy val log4jVersion = "2.22.0" -lazy val http4sVersion = "0.21.34" -lazy val scalaXmlVersion = "1.3.0" +lazy val http4sVersion = "0.23.24" +lazy val http4sBlazeClientVersion = "0.23.15" +lazy val scalaXmlVersion = "2.1.0" lazy val circeVersion = "0.12.2" lazy val scodecBitsVersion = "1.1.12" lazy val commonCodecVersion = "1.14" @@ -105,11 +106,12 @@ lazy val root = (project in file(".")) ), libraryDependencies ++= Seq( "org.scalatest" %% "scalatest" % scalatestVersion, + "org.typelevel" %% "cats-effect-testing-scalatest" % "1.5.0", "org.scalacheck" %% "scalacheck" % scalacheckVersion, "org.scalatestplus" %% "scalacheck-1-14" % scalatestScalacheckVersion, "org.apache.logging.log4j" % "log4j-api" % log4jVersion, "org.apache.logging.log4j" % "log4j-slf4j-impl" % log4jVersion, - "org.http4s" %% "http4s-blaze-client" % http4sVersion + "org.http4s" %% "http4s-blaze-client" % http4sBlazeClientVersion ).map(_ % s"$Test,$IntegrationTest"), scalafmtOnCompile := true ) @@ -164,9 +166,9 @@ lazy val s3 = (project in file("modules/s3")) .settings(automateHeaderSettings(IntegrationTest)) .settings( libraryDependencies ++= Seq( - "org.http4s" %% "http4s-scala-xml" % http4sVersion, + "org.http4s" %% "http4s-scala-xml" % "0.23.13", "org.scala-lang.modules" %% "scala-xml" % scalaXmlVersion, - "org.http4s" %% "http4s-blaze-client" % http4sVersion % Optional, + "org.http4s" %% "http4s-blaze-client" % http4sBlazeClientVersion % Optional, "software.amazon.awssdk" % "s3" % awsSdkVersion % s"$Test,$IntegrationTest" ) ) diff --git a/modules/auth/src/it/scala/auth/AwsSignerItSpec.scala b/modules/auth/src/it/scala/auth/AwsSignerItSpec.scala index 93384217..3bf3ec4d 100644 --- a/modules/auth/src/it/scala/auth/AwsSignerItSpec.scala +++ b/modules/auth/src/it/scala/auth/AwsSignerItSpec.scala @@ -3,23 +3,20 @@ package auth import common._ import common.model._ -import cats.effect.{ContextShift, IO} - +import cats.effect.IO +import cats.effect.testing.scalatest.AsyncIOSpec import org.http4s.client.Client -import org.http4s.client.blaze.BlazeClientBuilder +import org.http4s.blaze.client.BlazeClientBuilder import org.http4s.client.dsl.Http4sClientDsl import org.http4s.Method._ import org.http4s.headers._ -import org.http4s.{MediaType, Status, Uri} +import org.http4s.{MediaType, Request, Status, Uri} import org.http4s.client.middleware.{RequestLogger, ResponseLogger} -import scala.concurrent.duration._ -import scala.concurrent.ExecutionContext.global import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain -class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] { +class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] with AsyncIOSpec { - implicit val ctx: ContextShift[IO] = IO.contextShift(scala.concurrent.ExecutionContext.global) // This is our UAT environment private val esEndpoint = "" @@ -40,15 +37,12 @@ class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] { val signedClient: Client[IO] = awsSigner(requestLogger(responseLogger(client))) - for { - req <- GET( - Uri.unsafeFromString("https://s3-eu-west-1.amazonaws.com/ovo-comms-test/more.pdf") - ) - status <- signedClient.status(req) - } yield { - status.isSuccess shouldBe true - } - }.futureValue(timeout(scaled(5.seconds)), interval(500.milliseconds)) + val req = Request[IO]( + uri = Uri.unsafeFromString("https://s3-eu-west-1.amazonaws.com/ovo-comms-test/more.pdf") + ) + + signedClient.status(req).map(_.isSuccess) + }.asserting(_ shouldBe true) } "sign request valid for S3 with nested paths" in { @@ -66,13 +60,11 @@ class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] { val signedClient: Client[IO] = awsSigner(requestLogger(responseLogger(client))) - for { - req <- GET( - Uri.unsafeFromString("https://s3-eu-west-1.amazonaws.com/ovo-comms-test/test/more.pdf") - ) - status <- signedClient.status(req) - } yield status - }.futureValue(timeout(scaled(5.seconds)), interval(500.milliseconds)) should (not be Status.Unauthorized and not be Status.Forbidden) + val req = Request[IO]( uri = Uri.unsafeFromString("https://s3-eu-west-1.amazonaws.com/ovo-comms-test/test/more.pdf") + ) + + signedClient.status(req) + }.asserting(_ should (not be Status.Unauthorized and not be Status.Forbidden)) } "sign request valid for ES GET" ignore { @@ -91,10 +83,9 @@ class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] { val signedClient: Client[IO] = awsSigner(requestLogger(responseLogger(client))) for { - req <- GET(Uri.unsafeFromString(s"$esEndpoint/audit-2018-09/_doc/foo")) - status <- signedClient.status(req) + status <- signedClient.status(GET(Uri.unsafeFromString(s"$esEndpoint/audit-2018-09/_doc/foo"))) } yield status - }.futureValue(timeout(scaled(5.seconds)), interval(500.milliseconds)) should (not be Status.Unauthorized and not be Status.Forbidden) + }.asserting(_ should (not be Status.Unauthorized and not be Status.Forbidden)) } "sign request valid for ES POST" ignore { @@ -123,14 +114,13 @@ class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] { """ for { - req <- POST( + status <- signedClient.status(POST( body, Uri.unsafeFromString(s"$esEndpoint/audit-2018-09/_doc/_search"), `Content-Type`(MediaType.application.json) - ) - status <- signedClient.status(req) + )) } yield status - }.futureValue(timeout(scaled(5.seconds)), interval(500.milliseconds)) should (not be Status.Unauthorized and not be Status.Forbidden) + }.asserting(_ should (not be Status.Unauthorized and not be Status.Forbidden)) } "sign request valid for ES POST with multiple indexes" ignore { @@ -159,16 +149,15 @@ class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] { """ for { - req <- POST( + status <- signedClient.status(POST( body, Uri.unsafeFromString( s"$esEndpoint/audit-2018-09,audit-2018-10,audit-2018-11/_doc/_search?ignore_unavailable=true" ), `Content-Type`(MediaType.application.json) - ) - status <- signedClient.status(req) + )) } yield status - }.futureValue(timeout(scaled(5.seconds)), interval(500.milliseconds)) should (not be Status.Unauthorized and not be Status.Forbidden) + }.asserting(_ should (not be Status.Unauthorized and not be Status.Forbidden)) } "sign request valid for ES POST with query" ignore { @@ -197,16 +186,15 @@ class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] { """ for { - req <- POST( + status <- signedClient.status(POST( body, Uri.unsafeFromString( s"$esEndpoint/audit-2018-09/_doc/_search?ignore_unavailable=true" ), `Content-Type`(MediaType.application.json) - ) - status <- signedClient.status(req) + )) } yield status - }.futureValue(timeout(scaled(5.seconds)), interval(500.milliseconds)) should (not be Status.Unauthorized and not be Status.Forbidden) + }.asserting(_ should (not be Status.Unauthorized and not be Status.Forbidden)) } "sign request valid for ES POST with query and multiple parameters" ignore { @@ -235,16 +223,15 @@ class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] { """ for { - req <- POST( + status <- signedClient.status(POST( body, Uri.unsafeFromString( s"$esEndpoint/audit-2018-09/_doc/_search?ignore_unavailable=true&refresh=true" ), `Content-Type`(MediaType.application.json) - ) - status <- signedClient.status(req) + )) } yield status - }.futureValue(timeout(scaled(5.seconds)), interval(500.milliseconds)) should (not be Status.Unauthorized and not be Status.Forbidden) + }.asserting(_ should (not be Status.Unauthorized and not be Status.Forbidden)) } "sign request valid for ES POST with query and multiple parameters with comas and stars" ignore { @@ -273,16 +260,15 @@ class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] { """ for { - req <- POST( + status <- signedClient.status(POST( body, Uri.unsafeFromString( "/audit-2018-09/_doc/_search?ignore_unavailable=true&refresh=true&foo*=foo&bar,baz=baz" ), `Content-Type`(MediaType.application.json) - ) - status <- signedClient.status(req) + )) } yield status - }.futureValue(timeout(scaled(5.seconds)), interval(500.milliseconds)) should (not be Status.Unauthorized and not be Status.Forbidden) + }.asserting(_ should (not be Status.Unauthorized and not be Status.Forbidden)) } "sign request valid for ES POST with star in path" ignore { @@ -311,19 +297,18 @@ class AwsSignerItSpec extends IntegrationSpec with Http4sClientDsl[IO] { """ for { - req <- POST( + status <- signedClient.status(POST( body, Uri.unsafeFromString(s"$esEndpoint/audit-*/_doc/_search"), `Content-Type`(MediaType.application.json) - ) - status <- signedClient.status(req) + )) } yield status - }.futureValue(timeout(scaled(5.seconds)), interval(500.milliseconds)) should (not be Status.Unauthorized and not be Status.Forbidden) + }.asserting(_ should (not be Status.Unauthorized and not be Status.Forbidden)) } } def withHttpClient[A](f: Client[IO] => IO[A]): IO[A] = { - BlazeClientBuilder[IO](global).resource + BlazeClientBuilder[IO].resource .use(f) } diff --git a/modules/auth/src/main/scala/auth/AwsSigner.scala b/modules/auth/src/main/scala/auth/AwsSigner.scala index 1f3d31ce..483b689f 100644 --- a/modules/auth/src/main/scala/auth/AwsSigner.scala +++ b/modules/auth/src/main/scala/auth/AwsSigner.scala @@ -20,8 +20,9 @@ package auth import AwsSigner._ import common._ import common.model._ -import headers.{`X-Amz-Content-SHA256`, `X-Amz-Security-Token`, `X-Amz-Date`} -import cats.effect.{Sync, Resource} +import headers.{`X-Amz-Content-SHA256`, `X-Amz-Date`, `X-Amz-Security-Token`} +import headers.`X-Amz-Date`._ +import cats.effect.{Resource, Sync} import cats.implicits._ import scala.util.matching.Regex @@ -31,19 +32,18 @@ import java.security.MessageDigest import java.time._ import java.time.format.DateTimeFormatter import java.time.temporal.ChronoUnit - import org.slf4j.LoggerFactory import javax.crypto.Mac import javax.crypto.spec.SecretKeySpec import fs2.hash._ -import org.http4s.{Request, HttpDate, Response} +import org.http4s.{HttpDate, Request, Response, Uri} import org.http4s.Header.Raw import org.http4s.client.Client import org.http4s.headers.{Date, Host} -import org.http4s.syntax.all._ - import org.apache.commons.codec.binary.Hex +import org.http4s.Header.Select.singleHeaders +import org.typelevel.ci._ object AwsSigner { @@ -53,7 +53,7 @@ object AwsSigner { DateTimeFormatter.ofPattern("yyyyMMdd") val dateTimeFormatter: DateTimeFormatter = - DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'") + DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmssX") val DoubleSlashRegex: Regex = "/{2,}".r val MultipleSpaceRegex: Regex = "\\s+".r @@ -124,9 +124,9 @@ object AwsSigner { def extractXAmzDateOrDate[F[_]](request: Request[F]): Option[Instant] = { request.headers - .get(`X-Amz-Date`) + .get[`X-Amz-Date`] .map(_.date) - .orElse(request.headers.get(Date).map(_.date)) + .orElse(request.headers.get[Date].map(_.date)) .map(_.toInstant) } @@ -140,7 +140,7 @@ object AwsSigner { extractXAmzDateOrDate(request).getOrElse(fallbackRequestDateTime) def addHostHeader(r: Request[F]): F[Request[F]] = - if (r.headers.get(Host).isEmpty) { + if (r.headers.get[Host].isEmpty) { val uri = r.uri F.fromOption( uri.host, @@ -154,7 +154,7 @@ object AwsSigner { } def addXAmzDateHeader(r: Request[F]): F[Request[F]] = - (if (r.headers.get(Date).isEmpty && r.headers.get(`X-Amz-Date`).isEmpty) { + (if (r.headers.get[Date].isEmpty && r.headers.get[`X-Amz-Date`].isEmpty) { r.putHeaders(`X-Amz-Date`(HttpDate.unsafeFromInstant(requestDateTime))) } else { r @@ -166,7 +166,7 @@ object AwsSigner { .pure[F] def addHashedBody(r: Request[F]): F[Request[F]] = - if (r.headers.get(`X-Amz-Content-SHA256`).isEmpty) { + if (r.headers.get[`X-Amz-Content-SHA256`].isEmpty) { // TODO Add chunking support for S3 def unChunk(request: Request[F]): F[Request[F]] = @@ -207,7 +207,7 @@ object AwsSigner { val hashedPayloadF: F[String] = { val headerValue = request.headers - .get(`X-Amz-Content-SHA256`) + .get[`X-Amz-Content-SHA256`] .map(_.hashedContent) headerValue.fold(hashBody(request))(_.pure[F]) @@ -237,7 +237,7 @@ object AwsSigner { val (canonicalHeaders, signedHeaders) = { - val grouped = request.headers.toList.groupBy(_.name) + val grouped = request.headers.headers.groupBy(_.name) val combined = grouped.mapValues( _.map(h => MultipleSpaceRegex.replaceAllIn(h.value, " ").trim) .mkString(",") @@ -245,13 +245,12 @@ object AwsSigner { val canonical = combined.toSeq .sortBy(_._1) - .map { case (k, v) => s"${k.value.toLowerCase}:$v\n" } + .map { case (k, v) => s"${k.toString.toLowerCase}:$v\n" } .mkString("") val signed: String = - request.headers.toList - .map(_.name.value.toLowerCase) - .toSeq + request.headers.headers + .map(_.name.toString.toLowerCase) .distinct .sorted .mkString(";") @@ -264,18 +263,21 @@ object AwsSigner { val method = request.method.name.toUpperCase val canonicalUri = { - val absolutePath = - if (request.uri.path.startsWith("/")) request.uri.path - else "/" ++ request.uri.path + val absolutePath = { + if (request.uri.path.startsWithString("/")) request.uri.path + else Uri.Path.unsafeFromString("/").concat(request.uri.path) + } // you do not normalize URI paths for requests to Amazon S3 val normalizedPath = if (service != Service.S3) { - DoubleSlashRegex.replaceAllIn(absolutePath, "/") + DoubleSlashRegex.replaceAllIn(absolutePath.renderString, "/") } else { - absolutePath + absolutePath.renderString } - val encodedOnceSegments = normalizedPath + val handleEmptyPath = if (normalizedPath.isEmpty) "/" else normalizedPath + + val encodedOnceSegments = handleEmptyPath .split("/", -1) .map(uriEncode) @@ -335,7 +337,7 @@ object AwsSigner { val authorizationHeaderValue = s"$algorithm Credential=${credentials.accessKeyId.value}/$scope, SignedHeaders=$signedHeaders, Signature=$signature" - Raw("Authorization".ci, authorizationHeaderValue) + Raw(ci"Authorization", authorizationHeaderValue) } request.putHeaders(authorizationHeader) diff --git a/modules/auth/src/test/scala/auth/AwsSignV4TestSuiteSpec.scala b/modules/auth/src/test/scala/auth/AwsSignV4TestSuiteSpec.scala index f75f6a42..6c151a9e 100644 --- a/modules/auth/src/test/scala/auth/AwsSignV4TestSuiteSpec.scala +++ b/modules/auth/src/test/scala/auth/AwsSignV4TestSuiteSpec.scala @@ -21,13 +21,14 @@ import common._ import headers._ import model._ import Credentials._ - +import cats.effect.testing.scalatest.AsyncIOSpec import cats.effect.{IO, Sync} - import org.http4s.client.dsl.Http4sClientDsl import org.http4s.Method._ +import org.http4s.headers.Authorization import org.http4s.{HttpDate, Request, Uri} import org.http4s.syntax.all._ +import org.typelevel.ci.CIStringSyntax /* Validate AwsSigner with test cases from 'AWS Signature Version 4 Test Suite' @@ -35,7 +36,7 @@ import org.http4s.syntax.all._ Download the test cases zip file, extract in the project root dir and run tests. */ -class AwsSignV4TestSuiteSpec extends UnitSpec with Http4sClientDsl[IO] { +class AwsSignV4TestSuiteSpec extends UnitSpec with Http4sClientDsl[IO] with AsyncIOSpec { import cats.data.EitherT val DateFormatter = @@ -80,13 +81,13 @@ class AwsSignV4TestSuiteSpec extends UnitSpec with Http4sClientDsl[IO] { testCase <- EitherT(getTestCase[IO](testFile.getAbsolutePath)) (request, expectedSignature) = testCase res <- EitherT.right[String](withSignRequest(IO(request)) { signed => - val signature = signed.headers.get("Authorization".ci).get.value - IO(signature shouldBe expectedSignature) + val signature = signed.headers.get("Authorization".ci).get.map(_.value) + IO(signature.toList shouldBe List(expectedSignature)) }) } yield res).value.map { case Left(msg) => fail(msg) case Right(a) => a - }.futureValue + } } } } @@ -122,8 +123,10 @@ class AwsSignV4TestSuiteSpec extends UnitSpec with Http4sClientDsl[IO] { def headers(rows: List[String]) = rows.map(_.split(":", 2).toList).collect { case "X-Amz-Date" :: v :: Nil => - `X-Amz-Date`(HttpDate.unsafeFromZonedDateTime(parseTestCaseDate(v))) - case k :: v :: Nil => Header(k, v) + Header.ToRaw.modelledHeadersToRaw( + `X-Amz-Date`(HttpDate.unsafeFromZonedDateTime(parseTestCaseDate(v))) + ) + case k :: v :: Nil => Header.ToRaw.rawToRaw(Header.Raw(k.ci, v)) } (requestText match { case RequestRe(requestSection, _, body) => diff --git a/modules/auth/src/test/scala/auth/AwsSignerSpec.scala b/modules/auth/src/test/scala/auth/AwsSignerSpec.scala index 1b1d0543..5975c36c 100644 --- a/modules/auth/src/test/scala/auth/AwsSignerSpec.scala +++ b/modules/auth/src/test/scala/auth/AwsSignerSpec.scala @@ -25,7 +25,6 @@ import Credentials._ import java.security.MessageDigest import java.time.{Instant, LocalDateTime, ZoneOffset} import java.time.temporal.ChronoUnit - import cats.implicits._ import cats.effect.IO import fs2._ @@ -33,11 +32,14 @@ import fs2.hash._ import org.http4s.client.dsl.Http4sClientDsl import org.http4s.Method._ import org.http4s.headers._ -import org.http4s.{HttpDate, MediaType, Request} +import org.http4s.{Headers, HttpDate, MediaType, Request} import AwsSigner._ +import cats.effect.testing.scalatest.AsyncIOSpec +import org.http4s.Header.Select.singleHeaders import org.http4s.syntax.all._ +import org.typelevel.ci.CIStringSyntax -class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { +class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] with AsyncIOSpec { "digest" should { "calculate the correct digest" in forAll() { data: Array[Byte] => @@ -66,9 +68,9 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { "Request with no body" should { "not have empty body stream" in { (for { - req <- GET.apply(uri"https://example.com") + req <- IO(GET.apply(uri"https://example.com")) last <- req.body.compile.last - } yield last).futureValue shouldBe None + } yield last).asserting(_ shouldBe None) } } @@ -92,15 +94,15 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { val expectedXAmzDate = `X-Amz-Date`(HttpDate.unsafeFromInstant(now)) withFixedRequest( - GET(uri"http://example.com") - .map(_.removeHeader(Date).removeHeader(`X-Amz-Date`)), + IO(GET(uri"http://example.com")) + .map(_.removeHeader[Date].removeHeader[`X-Amz-Date`]), now ) { r => IO { - r.headers.get(Date) shouldBe None - r.headers.get(`X-Amz-Date`) shouldBe Some(expectedXAmzDate) + r.headers.get[Date] shouldBe None + r.headers.get[`X-Amz-Date`] shouldBe Some(expectedXAmzDate) } - }.futureValue + } } } @@ -111,15 +113,14 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { HttpDate.unsafeFromInstant(now.minus(5, ChronoUnit.MINUTES)) ) withFixedRequest( - GET(uri"http://example.com") - .map(_.removeHeader(Date).putHeaders(expectedXAmzDate)), + IO(Request[IO](uri = uri"http://example.com", headers = Headers(expectedXAmzDate))), now ) { r => IO { - r.headers.get(Date) shouldBe None - r.headers.get(`X-Amz-Date`) shouldBe Some(expectedXAmzDate) + r.headers.get[Date] shouldBe None + r.headers.get[`X-Amz-Date`] shouldBe Some(expectedXAmzDate) } - }.futureValue + } } } } @@ -130,29 +131,28 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { val expectedDate = Date(HttpDate.unsafeFromInstant(now.minus(5, ChronoUnit.MINUTES))) withFixedRequest( - GET(uri"http://example.com") - .map(_.removeHeader(`X-Amz-Date`).putHeaders(expectedDate)), + IO(GET(uri"http://example.com")) + .map(_.removeHeader[`X-Amz-Date`].putHeaders(expectedDate)), now ) { r => IO { - r.headers.get(`X-Amz-Date`) shouldBe None - r.headers.get(Date) shouldBe Some(expectedDate) + r.headers.get[`X-Amz-Date`] shouldBe None + r.headers.get[Date] shouldBe Some(expectedDate) } - }.unsafeRunSync() + } } - } - "Host header is defined" when { + "Host header is defined" should { "the uri is absolute" should { "not add Host header" in { val expectedHost = Host("foo", 5555) withFixedRequest( - GET(uri"http://example.com") + IO(GET(uri"http://example.com")) .map(_.putHeaders(expectedHost)) ) { r => IO { - r.headers.get(Host) shouldBe Some(expectedHost) + r.headers.get[Host] shouldBe Some(expectedHost) } }.unsafeRunSync() } @@ -160,11 +160,10 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { "the uri is relative" should { "fail the effect" in { - withFixedRequest(GET(uri"/foo/bar"))(_ => IO.unit).attempt - .unsafeRunSync() shouldBe a[Left[_, _]] + withFixedRequest(IO(GET(uri"/foo/bar")))(_ => IO.unit).attempt + .asserting(_ shouldBe a[Left[_, _]]) } } - } "Credentials contain the session token" should { @@ -178,15 +177,13 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { ) val expectedXAmzSecurityToken = `X-Amz-Security-Token`(sessionToken) withFixedRequest( - GET(uri"http://example.com"), + IO(GET(uri"http://example.com")), credentials = credentials ) { r => IO { - r.headers.get(`X-Amz-Security-Token`) shouldBe Some( - expectedXAmzSecurityToken - ) + r.headers.get[`X-Amz-Security-Token`] shouldBe Some(expectedXAmzSecurityToken) } - }.futureValue + } } } @@ -196,13 +193,13 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { val credentials = Credentials(AccessKeyId("FOO"), SecretAccessKey("BAR")) withFixedRequest( - GET(uri"http://example.com"), + IO(GET(uri"http://example.com")), credentials = credentials ) { r => IO { - r.headers.get(`X-Amz-Security-Token`) shouldBe None + r.headers.get[`X-Amz-Security-Token`] shouldBe None } - }.futureValue + } } } } @@ -211,9 +208,8 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { "Date header is not defined" when { "X-Amz-Date is not defined" should { "return a failed effect" in { - withSignRequest(GET(uri"http://example.com"))(_ => IO.unit).attempt.futureValue shouldBe a[ - Left[_, _] - ] + withSignRequest(IO(GET(uri"http://example.com")))(_ => IO.unit).attempt + .asserting(_ shouldBe a[Left[_, _]]) } } } @@ -234,25 +230,28 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { .parse("20150830T123600Z", AwsSigner.dateTimeFormatter) .atZone(ZoneOffset.UTC) - val request = GET( - uri"/", - Host("example.amazonaws.com"), - `X-Amz-Date`(HttpDate.unsafeFromZonedDateTime(dateTime)) + val request = Request[IO]( + uri = uri"/", + headers = Headers( + Host("example.amazonaws.com"), + `X-Amz-Date`(HttpDate.unsafeFromZonedDateTime(dateTime)) + ) ) withSignRequest( - request, + IO(request), credentials = credentials, region = Region.`us-east-1`, service = Service("service") ) { r => IO( r.headers - .get("Authorization".ci) + .get(ci"Authorization") .get - .value shouldBe expectedAuthorizationValue + .head + .value ) - }.futureValue + }.asserting(_ shouldBe expectedAuthorizationValue) } "sign a vanilla POST request correctly" in { @@ -267,25 +266,29 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { .parse("20150830T123600Z", AwsSigner.dateTimeFormatter) .atZone(ZoneOffset.UTC) - val request = POST( - uri"/", - Host("example.amazonaws.com"), - `X-Amz-Date`(HttpDate.unsafeFromZonedDateTime(dateTime)) + val request = Request[IO]( + method = POST, + uri = uri"/", + headers = Headers( + Host("example.amazonaws.com"), + `X-Amz-Date`(HttpDate.unsafeFromZonedDateTime(dateTime)) + ) ) withSignRequest( - request, + IO(request), credentials = credentials, region = Region.`us-east-1`, service = Service("service") ) { r => IO( r.headers - .get("Authorization".ci) + .get(ci"Authorization") .get + .head .value ) - }.futureValue shouldBe expectedAuthorizationValue + }.asserting(_ shouldBe expectedAuthorizationValue) } "sign a POST request with body" in { @@ -300,15 +303,16 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { .parse("20150830T123600Z", AwsSigner.dateTimeFormatter) .atZone(ZoneOffset.UTC) - val request = POST - .apply( - "Param1=value1", - uri"/", - Host("example.amazonaws.com"), - `Content-Type`(MediaType.application.`x-www-form-urlencoded`), - `X-Amz-Date`(HttpDate.unsafeFromZonedDateTime(dateTime)) - ) - .map(_.removeHeader(`Content-Length`)) + val request = IO( + POST + .apply( + "Param1=value1", + uri"/", + Host("example.amazonaws.com"), + `Content-Type`(MediaType.application.`x-www-form-urlencoded`), + `X-Amz-Date`(HttpDate.unsafeFromZonedDateTime(dateTime)) + ) + ).map(_.removeHeader[`Content-Length`]) withSignRequest( request, @@ -318,11 +322,12 @@ class AwsSignerSpec extends UnitSpec with Http4sClientDsl[IO] { ) { r => IO( r.headers - .get("Authorization".ci) + .get(ci"Authorization") .get - .value shouldBe expectedAuthorizationValue + .head + .value ) - }.futureValue + }.asserting(_ shouldBe expectedAuthorizationValue) } } diff --git a/modules/common/src/it/scala/common/IntegrationSpec.scala b/modules/common/src/it/scala/common/IntegrationSpec.scala index b9fcada6..2c5d57c1 100644 --- a/modules/common/src/it/scala/common/IntegrationSpec.scala +++ b/modules/common/src/it/scala/common/IntegrationSpec.scala @@ -1,10 +1,10 @@ package com.ovoenergy.comms.aws package common -import org.scalatest.wordspec.AnyWordSpec +import org.scalatest.wordspec.AsyncWordSpec import org.scalatest.matchers.should.Matchers -abstract class IntegrationSpec extends AnyWordSpec with Matchers with IOFutures { +abstract class IntegrationSpec extends AsyncWordSpec with Matchers { sys.props.put("log4j.configurationFile", "log4j2-it.xml") } diff --git a/modules/common/src/main/scala/common/HttpCodecs.scala b/modules/common/src/main/scala/common/HttpCodecs.scala index ff90b0d4..7d1f79a4 100644 --- a/modules/common/src/main/scala/common/HttpCodecs.scala +++ b/modules/common/src/main/scala/common/HttpCodecs.scala @@ -34,7 +34,7 @@ trait HttpCodecs { new HttpCodec[HttpDate] { private val dateTimeFormatter = - DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'") + DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmssX") override def parse(s: String): ParseResult[HttpDate] = Try(ZonedDateTime.parse(s, dateTimeFormatter)).toEither diff --git a/modules/common/src/main/scala/common/headers.scala b/modules/common/src/main/scala/common/headers.scala index 82bacba9..cb87b1ff 100644 --- a/modules/common/src/main/scala/common/headers.scala +++ b/modules/common/src/main/scala/common/headers.scala @@ -20,27 +20,15 @@ package common import model.Credentials._ import java.time._ - -import org.http4s._ -import syntax.all._ -import Header.Raw -import util.{CaseInsensitiveString, Writer} - +import org.http4s.{Header, _} +import org.http4s.Header.Raw import cats.implicits._ +import org.typelevel.ci.{CIString, _} object headers extends HttpCodecs { - object `X-Amz-Date` extends HeaderKey.Singleton { - type HeaderT = `X-Amz-Date` - - val name: CaseInsensitiveString = "X-Amz-Date".ci - - def matchHeader(header: Header): Option[`X-Amz-Date`] = header match { - case h: `X-Amz-Date` => h.some - case Raw(n, _) if n == name => - header.parsed.asInstanceOf[`X-Amz-Date`].some - case _ => None - } + object `X-Amz-Date` { + def name: CIString = ci"X-Amz-Date" def parse(s: String): ParseResult[`X-Amz-Date`] = HttpCodec[HttpDate].parse(s).map(`X-Amz-Date`.apply) @@ -56,80 +44,87 @@ object headers extends HttpCodecs { def unsafeFromDateTime(dateTime: ZonedDateTime): `X-Amz-Date` = { `X-Amz-Date`(HttpDate.unsafeFromZonedDateTime(dateTime)) } - } - final case class `X-Amz-Date`(date: HttpDate) extends Header.Parsed { - def key: `X-Amz-Date`.type = `X-Amz-Date` + def matchHeader(header: Any): Option[`X-Amz-Date`] = header match { + case h: `X-Amz-Date` => h.some + case Raw(n, v) if n == name => parse(v).toOption + case _ => None + } - def renderValue(writer: Writer): writer.type = writer << date + implicit val dateInstance: Header[`X-Amz-Date`, Header.Single] = Header.createRendered( + `X-Amz-Date`.name, + _.date, + `X-Amz-Date`.parse + ) } - object `X-Amz-Content-SHA256` extends HeaderKey.Singleton { - type HeaderT = `X-Amz-Content-SHA256` + final case class `X-Amz-Date`(date: HttpDate) - val name: CaseInsensitiveString = "X-Amz-Content-SHA256".ci + object `X-Amz-Content-SHA256` extends { + val name: CIString = ci"X-Amz-Content-SHA256" - def matchHeader(header: Header): Option[`X-Amz-Content-SHA256`] = + def matchHeader(header: Any): Option[`X-Amz-Content-SHA256`] = header match { case h: `X-Amz-Content-SHA256` => h.some - case Raw(n, _) if n == name => - header.parsed.asInstanceOf[`X-Amz-Content-SHA256`].some + case Raw(n, v) if n == name => parse(v).toOption case _ => None } def parse(s: String): ParseResult[`X-Amz-Content-SHA256`] = `X-Amz-Content-SHA256`(s).asRight - } - final case class `X-Amz-Content-SHA256`(hashedContent: String) extends Header.Parsed { - def key: `X-Amz-Content-SHA256`.type = `X-Amz-Content-SHA256` - - def renderValue(writer: Writer): writer.type = writer << hashedContent + implicit val contentSha256Instance: Header[`X-Amz-Content-SHA256`, Header.Single] = + Header.createRendered( + `X-Amz-Content-SHA256`.name, + _.hashedContent, + `X-Amz-Content-SHA256`.parse + ) } - object `X-Amz-Security-Token` extends HeaderKey.Singleton { - type HeaderT = `X-Amz-Security-Token` + final case class `X-Amz-Content-SHA256`(hashedContent: String) - val name: CaseInsensitiveString = "X-Amz-Security-Token".ci + object `X-Amz-Security-Token` { + val name: CIString = ci"X-Amz-Security-Token" - def matchHeader(header: Header): Option[`X-Amz-Security-Token`] = + def matchHeader(header: Any): Option[`X-Amz-Security-Token`] = header match { case h: `X-Amz-Security-Token` => h.some - case Raw(n, _) if n == name => - header.parsed.asInstanceOf[`X-Amz-Security-Token`].some + case Raw(n, v) if n == name => parse(v).toOption case _ => None } def parse(s: String): ParseResult[`X-Amz-Security-Token`] = HttpCodec[SessionToken].parse(s).map(`X-Amz-Security-Token`.apply) - } - final case class `X-Amz-Security-Token`(sessionToken: SessionToken) extends Header.Parsed { - def key: `X-Amz-Security-Token`.type = `X-Amz-Security-Token` - - def renderValue(writer: Writer): writer.type = writer << sessionToken + implicit val securityTokenInstance: Header[`X-Amz-Security-Token`, Header.Single] = + Header.createRendered( + `X-Amz-Security-Token`.name, + _.sessionToken, + `X-Amz-Security-Token`.parse + ) } - object `X-Amz-Target` extends HeaderKey.Singleton { - type HeaderT = `X-Amz-Target` + final case class `X-Amz-Security-Token`(sessionToken: SessionToken) - val name: CaseInsensitiveString = "X-Amz-Target".ci + object `X-Amz-Target` { + val name: CIString = ci"X-Amz-Target" - def matchHeader(header: Header): Option[`X-Amz-Target`] = + def matchHeader(header: Any): Option[`X-Amz-Target`] = header match { case h: `X-Amz-Target` => h.some - case Raw(n, _) if n == name => - header.parsed.asInstanceOf[`X-Amz-Target`].some + case Raw(n, v) if n == name => parse(v).toOption case _ => None } def parse(s: String): ParseResult[`X-Amz-Target`] = `X-Amz-Target`(s).asRight - } - final case class `X-Amz-Target`(target: String) extends Header.Parsed { - def key: `X-Amz-Target`.type = `X-Amz-Target` - - def renderValue(writer: Writer): writer.type = writer << target + implicit val targetInstance: Header[`X-Amz-Target`, Header.Single] = Header.createRendered( + `X-Amz-Target`.name, + _.target, + `X-Amz-Target`.parse + ) } + + final case class `X-Amz-Target`(target: String) } diff --git a/modules/common/src/test/scala/common/IOFutures.scala b/modules/common/src/test/scala/common/IOFutures.scala deleted file mode 100644 index e96240cf..00000000 --- a/modules/common/src/test/scala/common/IOFutures.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2018 OVO Energy - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.ovoenergy.comms.aws -package common - -import cats.effect.IO -import org.scalatest.concurrent.Futures - -import scala.util.{Failure, Success} - -trait IOFutures extends Futures { - - implicit def convertIO[T](io: IO[T]): FutureConcept[T] = - new FutureConcept[T] { - - private val futureFromIo = io.unsafeToFuture() - - def eitherValue: Option[Either[Throwable, T]] = - futureFromIo.value.map { - case Success(o) => Right(o) - case Failure(e) => Left(e) - } - def isExpired: Boolean = - false // Scala Futures themselves don't support the notion of a timeout - def isCanceled: Boolean = - false // Scala Futures don't seem to be cancelable either - } - -} diff --git a/modules/common/src/test/scala/common/UnitSpec.scala b/modules/common/src/test/scala/common/UnitSpec.scala index d53f9ebf..a7812587 100644 --- a/modules/common/src/test/scala/common/UnitSpec.scala +++ b/modules/common/src/test/scala/common/UnitSpec.scala @@ -19,10 +19,6 @@ package common import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks import org.scalatest.matchers.should.Matchers -import org.scalatest.wordspec.AnyWordSpec +import org.scalatest.wordspec.AsyncWordSpec -abstract class UnitSpec - extends AnyWordSpec - with Matchers - with ScalaCheckDrivenPropertyChecks - with IOFutures +abstract class UnitSpec extends AsyncWordSpec with Matchers with ScalaCheckDrivenPropertyChecks diff --git a/modules/s3/src/it/scala/s3/S3Spec.scala b/modules/s3/src/it/scala/s3/S3Spec.scala index 405be10d..1c6dd0b9 100644 --- a/modules/s3/src/it/scala/s3/S3Spec.scala +++ b/modules/s3/src/it/scala/s3/S3Spec.scala @@ -3,19 +3,16 @@ package s3 import cats.implicits._ import cats.effect._ +import cats.effect.testing.scalatest.AsyncIOSpec import fs2.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.util.UUID -import java.util.concurrent.Executors - -import scala.concurrent.duration._ - import common.model._ import common.{CredentialsProvider, IntegrationSpec} import model._ -class S3Spec extends IntegrationSpec { +class S3Spec extends IntegrationSpec with AsyncIOSpec { val existingKey = Key("more.pdf") val duplicateKey = Key("duplicate") @@ -36,16 +33,6 @@ class S3Spec extends IntegrationSpec { } val randomKey = IO(Key(UUID.randomUUID().toString)) - implicit val patience: PatienceConfig = - PatienceConfig(scaled(5.seconds), 500.millis) - - val blockingEc = scala.concurrent.ExecutionContext - .fromExecutorService(Executors.newCachedThreadPool()) - val blocker = Blocker.liftExecutionContext(blockingEc) - implicit val ctx: ContextShift[IO] = - IO.contextShift(scala.concurrent.ExecutionContext.global) - implicit val timer: Timer[IO] = IO.timer(blockingEc) - "headObject" when { "the bucket exists" when { @@ -54,23 +41,21 @@ class S3Spec extends IntegrationSpec { "return the object eTag" in { withS3 { s3 => s3.headObject(existingBucket, key) - }.futureValue.map { os => - os.eTag shouldBe Etag("9fe029056e0841dde3c1b8a169635f6f") - } + }.flatMap(IO.fromEither(_)).map(os => os.eTag shouldBe Etag("9fe029056e0841dde3c1b8a169635f6f")) } "return the object metadata" in { withS3 { s3 => s3.headObject(existingBucket, key) - }.futureValue.map { os => - os.metadata shouldBe Map("is-test" -> "true") + }.flatMap { + IO.fromEither(_).map(os => os.metadata shouldBe Map("is-test" -> "true")) } } "return the object contentLength" in { withS3 { s3 => s3.headObject(existingBucket, key) - }.futureValue.map { os => + }.flatMap(IO.fromEither(_)).map { os => os.contentLength should be > 0L } } @@ -81,21 +66,19 @@ class S3Spec extends IntegrationSpec { "return a Left" in { withS3 { s3 => s3.headObject(existingBucket, notExistingKey) - }.futureValue shouldBe a[Left[_, _]] + }.unsafeRunSync() shouldBe a[Left[_, _]] } "return NoSuchKey error code" in { withS3 { s3 => s3.headObject(existingBucket, notExistingKey) - }.futureValue.left.map { error => - error.code shouldBe Error.Code("NoSuchKey") - } + }.flatMap(IO.fromEither(_)).assertThrowsError[model.Error](error => error.code shouldBe Error.Code("NoSuchKey")) } "return the given key as resource" in { withS3 { s3 => s3.headObject(existingBucket, notExistingKey) - }.futureValue.left.map { error => + }.flatMap(IO.fromEither(_)).assertThrowsError[model.Error] { error => error.key shouldBe notExistingKey.some } } @@ -107,13 +90,13 @@ class S3Spec extends IntegrationSpec { "return a Left" in { withS3 { s3 => s3.headObject(nonExistingBucket, existingKey) - }.futureValue shouldBe a[Left[_, _]] + }.unsafeRunSync() shouldBe a[Left[_, _]] } "return NoSuchBucket error code" in { withS3 { s3 => s3.headObject(nonExistingBucket, existingKey) - }.futureValue.left.map { error => + }.flatMap(IO.fromEither(_)).assertThrowsError[model.Error] { error => error.code shouldBe Error.Code("NoSuchBucket") } } @@ -121,7 +104,7 @@ class S3Spec extends IntegrationSpec { "return the given bucket" in { withS3 { s3 => s3.headObject(nonExistingBucket, existingKey) - }.futureValue.left.map { error => + }.flatMap(IO.fromEither(_)).assertThrowsError[model.Error]{ error => error.bucketName shouldBe nonExistingBucket.some } } @@ -154,7 +137,7 @@ class S3Spec extends IntegrationSpec { .map(_.leftWiden[Throwable]) .rethrow .use(_.content.compile.toList) - }.futureValue should not be empty + }.unsafeRunSync() should not be empty } // FIXME This test does not pass, but we have verified manually that the connection is getting disposed @@ -162,11 +145,9 @@ class S3Spec extends IntegrationSpec { existingBucket, existingKey ) { objOrError => - objOrError.map { obj => - (obj.content.compile.toList >> obj.content.compile.toList.attempt).futureValue shouldBe a[ - Left[_, _] - ] - } + IO.fromEither(objOrError).flatMap { obj => + obj.content.compile.toList >> obj.content.compile.toList.attempt + }.asserting(_ shouldBe a[Left[_, _]]) } } @@ -179,15 +160,15 @@ class S3Spec extends IntegrationSpec { "return NoSuchKey error code" in checkGetObject(existingBucket, notExistingKey) { objOrError => objOrError.left.map { error => - error.code shouldBe Error.Code("NoSuchKey") - } + error.code + } shouldBe Left(Error.Code("NoSuchKey")) } "return the given key as resource" in checkGetObject(existingBucket, notExistingKey) { objOrError => objOrError.left.map { error => - error.key shouldBe notExistingKey.some - } + error.key + } shouldBe Left(notExistingKey.some) } } @@ -202,14 +183,14 @@ class S3Spec extends IntegrationSpec { "return NoSuchBucket error code" in checkGetObject(nonExistingBucket, existingKey) { objOrError => objOrError.left.map { error => - error.code shouldBe Error.Code("NoSuchBucket") - } + error.code + } shouldBe Left(Error.Code("NoSuchBucket")) } "return the given bucket" in checkGetObject(nonExistingBucket, existingKey) { objOrError => objOrError.left.map { error => - error.bucketName shouldBe nonExistingBucket.some - } + error.bucketName + } shouldBe Left(nonExistingBucket.some) } } @@ -224,7 +205,7 @@ class S3Spec extends IntegrationSpec { withS3 { s3 => val contentIo: IO[ObjectContent[IO]] = moreSize.map { size => ObjectContent( - readInputStream(morePdf, chunkSize = 64 * 1024, blocker), + readInputStream(morePdf, chunkSize = 64 * 1024), size, chunked = true ) @@ -236,7 +217,7 @@ class S3Spec extends IntegrationSpec { result <- s3.putObject(existingBucket, key, content) } yield result - }.futureValue shouldBe a[Right[_, _]] + }.unsafeRunSync() shouldBe a[Right[_, _]] } "upload the object content with custom metadata" in { @@ -246,7 +227,7 @@ class S3Spec extends IntegrationSpec { withS3 { s3 => val contentIo: IO[ObjectContent[IO]] = moreSize.map { size => ObjectContent( - readInputStream(morePdf, chunkSize = 64 * 1024, blocker), + readInputStream(morePdf, chunkSize = 64 * 1024), size, chunked = true ) @@ -258,7 +239,7 @@ class S3Spec extends IntegrationSpec { _ <- s3.putObject(existingBucket, key, content, expectedMetadata) summary <- s3.headObject(existingBucket, key) } yield summary - }.futureValue.map { summary => + }.flatMap(IO.fromEither(_)).map { summary => summary.metadata shouldBe expectedMetadata } } @@ -271,7 +252,7 @@ class S3Spec extends IntegrationSpec { withS3 { s3 => s3.putObject(existingBucket, nestedKey, content) - }.futureValue shouldBe a[Right[_, _]] + }.unsafeRunSync() shouldBe a[Right[_, _]] } } @@ -283,7 +264,7 @@ class S3Spec extends IntegrationSpec { withS3 { s3 => s3.putObject(existingBucket, slashLeadingKey, content) - }.futureValue shouldBe a[Right[_, _]] + }.unsafeRunSync() shouldBe a[Right[_, _]] } } @@ -297,7 +278,7 @@ class S3Spec extends IntegrationSpec { _ <- s3.putObject(existingBucket, duplicateKey, content) result <- s3.putObject(existingBucket, duplicateKey, content) } yield result - }.futureValue shouldBe a[Right[_, _]] + }.unsafeRunSync() shouldBe a[Right[_, _]] } } } @@ -311,7 +292,7 @@ class S3Spec extends IntegrationSpec { existingKey, ObjectContent.fromByteArray(Array.fill(128 * 1026)(0: Byte)) ) - }.futureValue shouldBe a[Left[_, _]] + }.unsafeRunSync() shouldBe a[Left[_, _]] } "return NoSuchBucket error code" in withS3 { s3 => @@ -320,14 +301,16 @@ class S3Spec extends IntegrationSpec { existingKey, ObjectContent.fromByteArray(Array.fill(128 * 1026)(0: Byte)) ) - }.futureValue.left.map { error => + }.flatMap(IO.fromEither(_)).assertThrowsError[model.Error] { error => error.bucketName shouldBe nonExistingBucket.some } } } def checkGetObject[A](bucket: Bucket, key: Key)(f: Either[Error, Object[IO]] => A): A = - withS3 { _.getObject(bucket, key).use(x => IO(f(x))) }.futureValue + withS3 { + _.getObject(bucket, key).use(x => IO(f(x))) + }.unsafeRunSync() def withS3[A](f: S3[IO] => IO[A]): IO[A] = { S3.resource(CredentialsProvider.default[IO], Region.`eu-west-1`).use(f) diff --git a/modules/s3/src/main/scala/s3/S3.scala b/modules/s3/src/main/scala/s3/S3.scala index 5ab7db29..603daeca 100644 --- a/modules/s3/src/main/scala/s3/S3.scala +++ b/modules/s3/src/main/scala/s3/S3.scala @@ -19,28 +19,27 @@ package s3 import cats.implicits._ import cats.effect._ -import java.nio.ByteBuffer -import org.http4s.syntax.all._ -import org.http4s.{Service => _, headers => _, _} +import java.nio.ByteBuffer +import org.http4s._ import org.http4s.headers._ import org.http4s.Method._ import org.http4s.Header.Raw import org.http4s.client.Client -import org.http4s.client.blaze.BlazeClientBuilder +import org.http4s.blaze.client.BlazeClientBuilder import org.http4s.client.dsl.Http4sClientDsl -import scala.concurrent.ExecutionContext -import scala.concurrent.duration._ +import scala.concurrent.duration._ import scalaxml._ -import scala.xml.Elem +import scala.xml.Elem import auth.AwsSigner -import headers._ import model._ import common._ import common.model._ import fs2.text +import com.ovoenergy.comms.aws.s3.headers._ +import org.typelevel.ci.CIStringSyntax trait S3[F[_]] { @@ -59,18 +58,17 @@ trait S3[F[_]] { object S3 { - def resource[F[_]: ConcurrentEffect: Timer]( + def resource[F[_]: Async]( credentialsProvider: CredentialsProvider[F], region: Region, - endpoint: Option[Uri] = None, - ec: ExecutionContext = ExecutionContext.global + endpoint: Option[Uri] = None ): Resource[F, S3[F]] = { - BlazeClientBuilder[F](ec).resource.map(client => + BlazeClientBuilder[F].resource.map(client => S3.apply(client, credentialsProvider, region, endpoint) ) } - def apply[F[_]: Sync: Timer]( + def apply[F[_]: Async]( client: Client[F], credentialsProvider: CredentialsProvider[F], region: Region, @@ -80,6 +78,7 @@ object S3 { val signer = AwsSigner[F](credentialsProvider, region, Service.S3) val signedClient = signer(client) + def baseEndpoint(bucket: Bucket) = endpoint.getOrElse { Uri.unsafeFromString(s"https://${bucket.name}.s3.${region.value}.amazonaws.com") } @@ -93,7 +92,7 @@ object S3 { bucket: Bucket, key: Key ): Resource[F, Either[Error, Object[F]]] = - Resource.eval(GET(uri(bucket, key))).flatMap { req => + Resource.eval(Sync[F].delay(GET(uri(bucket, key)))).flatMap { req => signedClient.run(req).evalMap { case r if r.status.isSuccess => parseObjectSummary(r).value.rethrow // TODO should lack of etag be an error here? @@ -114,7 +113,7 @@ object S3 { ): F[Either[Error, ObjectSummary]] = { for { - request <- GET(uri(bucket, key)) + request <- Sync[F].delay(GET(uri(bucket, key))) result <- withRetry(signedClient.run(request).use { case r if r.status.isSuccess => r.as[ObjectSummary].map(_.asRight[Error]) @@ -140,14 +139,11 @@ object S3 { Sync[F] .fromEither(`Content-Length`.fromLong(content.contentLength)) .map { contentLength => - Headers - .of( - contentLength, - `Content-Type`(content.mediaType, content.charset) - ) - .put(metadata.map { - case (k, v) => Raw(s"${`X-Amz-Meta-`}$k".ci, v) - }.toSeq: _*) + Headers( + contentLength, + `Content-Type`(content.mediaType, content.charset), + metadata.map { case (k, v) => Raw(ci"${`X-Amz-Meta-`}$k", v) }.toSeq + ) } val extractContent: F[Array[Byte]] = @@ -168,7 +164,8 @@ object S3 { for { hs <- initHeaders contentAsSingleChunk <- extractContent - request <- PUT(contentAsSingleChunk, uri(bucket, key), hs.toList: _*) + request = Request[F](method = PUT, uri = uri(bucket, key), headers = hs) + .withEntity(contentAsSingleChunk) result <- withRetry(signedClient.run(request).use { case r if r.status.isSuccess => r.as[ObjectPut].map(_.asRight[Error]) case r if r.status.responseClass == Status.ServerError => @@ -233,7 +230,7 @@ object S3 { strict: Boolean ): DecodeResult[F, ObjectPut] = { msg.headers - .get(ETag) + .get[ETag] .map(t => ObjectPut(Etag(t.tag.tag))) // TODO InvalidMessageBodyFailure is not correct here as there is no body .fold[DecodeResult[F, ObjectPut]]( @@ -255,7 +252,7 @@ object S3 { ): DecodeResult[F, ObjectSummary] = { val eTag: DecodeResult[F, Etag] = response.headers - .get(ETag) + .get[ETag] .map(t => DecodeResult.successT[F, Etag](Etag(t.tag.tag))) .getOrElse( DecodeResult.failureT[F, Etag]( @@ -266,7 +263,7 @@ object S3 { ) val mediaType: DecodeResult[F, MediaType] = response.headers - .get(`Content-Type`) + .get[`Content-Type`] .map(_.mediaType) .fold( DecodeResult.failureT[F, MediaType]( @@ -277,12 +274,12 @@ object S3 { )(DecodeResult.successT[F, MediaType]) val charset: DecodeResult[F, Option[Charset]] = response.headers - .get(`Content-Type`) + .get[`Content-Type`] .flatMap(_.charset) .traverse(DecodeResult.successT[F, Charset]) val contentLength = response.headers - .get(`Content-Length`) + .get[`Content-Length`] .map(_.length) .fold( DecodeResult.failureT[F, Long]( @@ -292,9 +289,9 @@ object S3 { ) )(DecodeResult.successT[F, Long]) - val metadata: Map[String, String] = response.headers.toList.collect { - case h if h.name.value.toLowerCase.startsWith(`X-Amz-Meta-`) => - h.name.value.substring(`X-Amz-Meta-`.length) -> h.value + val metadata: Map[String, String] = response.headers.headers.collect { + case h if h.name.toString.toLowerCase.startsWith(`X-Amz-Meta-`) => + h.name.toString.substring(`X-Amz-Meta-`.length) -> h.value }.toMap (eTag, mediaType, charset, contentLength).mapN { (eTag, mediaType, charset, contentLength) => @@ -310,8 +307,8 @@ object S3 { } - private def fOfBodyString[F[_]: Sync: Timer](r: Response[F]) = { - r.body.through(text.utf8Decode).compile.string + private def fOfBodyString[F[_]: Sync](r: Response[F]) = { + r.body.through(text.utf8.decode).compile.string } private case class RetriableServerError(bodyContent: String) extends Exception { diff --git a/modules/s3/src/main/scala/s3/headers.scala b/modules/s3/src/main/scala/s3/headers.scala index e1882d82..f8653126 100644 --- a/modules/s3/src/main/scala/s3/headers.scala +++ b/modules/s3/src/main/scala/s3/headers.scala @@ -18,13 +18,12 @@ package com.ovoenergy.comms.aws package s3 import model._ - import org.http4s._ import syntax.all._ import Header.Raw -import util.{CaseInsensitiveString, Writer} - +import util.Writer import cats.implicits._ +import org.typelevel.ci.CIString trait HttpCodecs { @@ -49,16 +48,14 @@ object headers extends HttpCodecs { val `X-Amz-Meta-` = "x-amz-meta-" - object `X-Amz-Storage-Class` extends HeaderKey.Singleton { - type HeaderT = `X-Amz-Storage-Class` + object `X-Amz-Storage-Class` { - val name: CaseInsensitiveString = "X-Amz-Storage-Class".ci + val name: CIString = "X-Amz-Storage-Class".ci - def matchHeader(header: Header): Option[`X-Amz-Storage-Class`] = + def matchHeader(header: Any): Option[`X-Amz-Storage-Class`] = header match { case h: `X-Amz-Storage-Class` => h.some - case Raw(n, _) if n == name => - header.parsed.asInstanceOf[`X-Amz-Storage-Class`].some + case Raw(n, v) if n == name => parse(v).toOption case _ => None } @@ -67,10 +64,13 @@ object headers extends HttpCodecs { } - final case class `X-Amz-Storage-Class`(storageClass: StorageClass) extends Header.Parsed { - def key: `X-Amz-Storage-Class`.type = `X-Amz-Storage-Class` + final case class `X-Amz-Storage-Class`(storageClass: StorageClass) - def renderValue(writer: Writer): writer.type = writer << storageClass - } + implicit val storageClassInstance: Header[`X-Amz-Storage-Class`, Header.Single] = + Header.createRendered( + `X-Amz-Storage-Class`.name, + _.storageClass, + `X-Amz-Storage-Class`.parse + ) } diff --git a/modules/s3/src/main/scala/s3/model.scala b/modules/s3/src/main/scala/s3/model.scala index 4ed78ebe..685fee32 100644 --- a/modules/s3/src/main/scala/s3/model.scala +++ b/modules/s3/src/main/scala/s3/model.scala @@ -132,17 +132,14 @@ object model { charset: Option[Charset] = None ): ObjectContent[F] = ObjectContent[F]( - data = Stream.chunk(Chunk.bytes(data)).covary[F], + data = Stream.chunk(Chunk.from(data)).covary[F], contentLength = data.length.toLong, mediaType = mediaType, charset = charset, chunked = false ) - def fromPath[F[_]: Sync: ContextShift]( - path: Path, - blocker: Blocker - ): F[ObjectContent[F]] = + def fromPath[F[_]: Async](path: Path): F[ObjectContent[F]] = Sync[F] .delay(Files.size(path)) .flatTap { contentLength => @@ -158,7 +155,6 @@ object model { ObjectContent( io.file.readAll[F]( path, - blocker, ChunkSize ), contentLength, diff --git a/modules/s3/src/test/scala/s3/utils/S3UriParserSpec.scala b/modules/s3/src/test/scala/s3/utils/S3UriParserSpec.scala index 43289ce2..6df0b917 100644 --- a/modules/s3/src/test/scala/s3/utils/S3UriParserSpec.scala +++ b/modules/s3/src/test/scala/s3/utils/S3UriParserSpec.scala @@ -19,68 +19,69 @@ package s3 package utils import cats.effect.IO +import cats.effect.testing.scalatest.AsyncIOSpec import com.ovoenergy.comms.aws.common.UnitSpec import com.ovoenergy.comms.aws.s3.model._ -class S3UriParserSpec extends UnitSpec { +class S3UriParserSpec extends UnitSpec with AsyncIOSpec { "S3UriParser" should { // Bucket "parse bucket name from authority" in { val testCase = "s3://some-bucket/object-key" - S3UriParser.getBucket[IO](testCase).futureValue shouldBe Bucket("some-bucket") + S3UriParser.getBucket[IO](testCase).asserting(_ shouldBe Bucket("some-bucket")) } "parse bucket name from authority with no key" in { val testCase = "s3://some-bucket" - S3UriParser.getBucket[IO](testCase).futureValue shouldBe Bucket("some-bucket") + S3UriParser.getBucket[IO](testCase).asserting(_ shouldBe Bucket("some-bucket")) } "return fail if authority empty" in { val testCase = "s3://" - S3UriParser.getBucket[IO](testCase).attempt.futureValue shouldBe a[Left[_, _]] + S3UriParser.getBucket[IO](testCase).attempt.asserting(_ shouldBe a[Left[_, _]]) } "parse bucket name from path" in { val testCase = "https://s3.amazonaws.com/some-bucket/key" - S3UriParser.getBucket[IO](testCase).futureValue shouldBe Bucket("some-bucket") + S3UriParser.getBucket[IO](testCase).asserting(_ shouldBe Bucket("some-bucket")) } "parse bucket name from path with no key" in { val testCase = "https://s3.amazonaws.com/some-bucket" - S3UriParser.getBucket[IO](testCase).futureValue shouldBe Bucket("some-bucket") + S3UriParser.getBucket[IO](testCase).asserting(_ shouldBe Bucket("some-bucket")) } "parse bucket name from path with trailing slash" in { val testCase = "https://s3.amazonaws.com/some-bucket/" - S3UriParser.getBucket[IO](testCase).futureValue shouldBe Bucket("some-bucket") + S3UriParser.getBucket[IO](testCase).asserting(_ shouldBe Bucket("some-bucket")) } // Key "Parse key from s3 format uri" in { val testCase = "s3://some-bucket/object-key" - S3UriParser.getKey[IO](testCase).futureValue shouldBe Key("object-key") + S3UriParser.getKey[IO](testCase).asserting(_ shouldBe Key("object-key")) } "Return None from s3 format uri with no key" in { val testCase = "s3://some-bucket" - S3UriParser.getKey[IO](testCase).attempt.futureValue shouldBe a[Left[_, _]] + S3UriParser.getKey[IO](testCase).attempt.asserting(_ shouldBe a[Left[_, _]]) } "Parse key from regular format uri" in { val testCase = "https://s3.amazonaws.com/some-bucket/object-key" - S3UriParser.getKey[IO](testCase).futureValue shouldBe Key("object-key") + S3UriParser.getKey[IO](testCase).asserting(_ shouldBe Key("object-key")) } "Parse multipart key from uri" in { val testCase = "https://s3.eu-west-1.amazonaws.com/bucket/key-part-1/key-part-2/key-part-3" - S3UriParser.getKey[IO](testCase).futureValue shouldBe Key("key-part-1/key-part-2/key-part-3") + S3UriParser.getKey[IO](testCase).asserting(_ shouldBe Key("key-part-1/key-part-2/key-part-3")) } "Return None from regular format uri with no key" in { val testCase = "https://s3.amazonaws.com/some-bucket/" - S3UriParser.getKey[IO](testCase).attempt.futureValue shouldBe a[Left[_, _]] + S3UriParser.getKey[IO](testCase).attempt.asserting(_ shouldBe a[Left[_, _]]) } } diff --git a/project/plugins.sbt b/project/plugins.sbt index 3f8bf494..b7241bd5 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -5,5 +5,6 @@ addSbtPlugin("de.heikoseeberger" % "sbt-header" % "5.6.0") addSbtPlugin("io.github.davidgregory084" % "sbt-tpolecat" % "0.4.4") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.10.0-RC1") addSbtPlugin("fr.qux" % "sbt-release-tags-only" % "0.5.0") +addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.11.1") resolvers += Resolver.sonatypeRepo("releases")