diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 00000000..88ab8dc3 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,90 @@ +version: 2.1 +jobs: + build: + docker: + - image: circleci/openjdk:11-jdk + + working_directory: ~/repo + + environment: + JVM_OPTS: -Xmx3200m + TERM: dumb + + steps: + - checkout + + # Download and cache dependencies + - restore_cache: + keys: + - v1-dependencies-{{ checksum "build.sc" }} + # fallback to using the latest cache if no exact match is found + - v1-dependencies- + + # https://circleci.com/docs/2.0/env-vars/#using-bash_env-to-set-environment-variables + - run: + name: install coursier + command: | + curl -fLo cs https://git.io/coursier-cli-"$(uname | tr LD ld)" + chmod +x cs + ./cs install cs + rm cs + echo "export PATH=$PATH:/home/circleci/.local/share/coursier/bin" >> $BASH_ENV + + - run: + name: install scalafmt + command: cs install scalafmt + + - run: + name: install mill + command: | + mkdir -p ~/.local/bin + (echo "#!/usr/bin/env sh" && curl -L https://github.com/lihaoyi/mill/releases/download/0.8.0/0.8.0) > ~/.local/bin/mill + chmod +x ~/.local/bin/mill + + - run: + name: check that the code is formatted properly + command: scalafmt --test + + # For some reason if I try to separate compile and test, then the subsequent test step does nothing. + - run: + name: compile and test + command: mill __.test + + - save_cache: + paths: + - ~/.ivy2 + - ~/.cache + key: v1-dependencies--{{ checksum "build.sc" }} + + - when: + condition: + or: + - equal: [ master, << pipeline.git.branch >> ] + - equal: [ develop, << pipeline.git.branch >> ] + steps: + - run: + name: install gpg2 + # GPG in docker needs to be run with some additional flags + # and we are not able to change how mill uses it + # this is why we're creating wrapper that adds the flags + command: | + sudo apt update + sudo apt install -y gnupg2 + sudo mv /usr/bin/gpg /usr/bin/gpg-vanilla + sudo sh -c "echo '#!/bin/sh\n\n/usr/bin/gpg-vanilla --no-tty --pinentry loopback \$@' > /usr/bin/gpg" + sudo chmod 755 /usr/bin/gpg + cat /usr/bin/gpg + + - run: + name: install base64 + command: sudo apt update && sudo apt install -y cl-base64 + + - run: + name: publish + command: .circleci/publish.sh + no_output_timeout: 30m + +workflows: + build_and_publish: + jobs: + - build diff --git a/.circleci/publish.sh b/.circleci/publish.sh new file mode 100755 index 00000000..63b8197a --- /dev/null +++ b/.circleci/publish.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +set -euv + +echo $GPG_KEY | base64 --decode | gpg --batch --import + +gpg --passphrase $GPG_PASSPHRASE --batch --yes -a -b LICENSE + +if [[ "$CIRCLE_BRANCH" == "develop" ]]; then + +mill mill.scalalib.PublishModule/publishAll \ + __.publishArtifacts \ + "$OSS_USERNAME":"$OSS_PASSWORD" \ + --gpgArgs --passphrase="$GPG_PASSPHRASE",--batch,--yes,-a,-b + +elif [[ "$CIRCLE_BRANCH" == "master" ]]; then + +mill versionFile.setReleaseVersion +mill mill.scalalib.PublishModule/publishAll \ + __.publishArtifacts \ + "$OSS_USERNAME":"$OSS_PASSWORD" \ + --gpgArgs --passphrase="$GPG_PASSPHRASE",--batch,--yes,-a,-b \ + --readTimeout 600000 \ + --awaitTimeout 600000 \ + --release true + +else + + echo "Skipping publish step" + +fi diff --git a/.gitignore b/.gitignore index 9c07d4ae..8d68b3db 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,8 @@ *.class *.log +.bloop +.metals +.vscode +out/ +*.iml +/.idea* diff --git a/.mill-version b/.mill-version new file mode 100644 index 00000000..a3df0a69 --- /dev/null +++ b/.mill-version @@ -0,0 
+1 @@ +0.8.0 diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 00000000..8c23f5fc --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1,3 @@ +version = "2.7.4" +maxColumn = 80 +align.preset = more diff --git a/README.md b/README.md index d023585e..933a9e50 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,86 @@ -# metronome -Checkpointing PoW blockchains with HotStuff BFT +# Metronome + +Metronome is a checkpointing component for Proof-of-Work blockchains, using the [HotStuff BFT](https://arxiv.org/pdf/1803.05069.pdf) algorithm. + +## Overview +Checkpointing provides finality to blockchains by attesting to the hash of well-embedded blocks. A proper checkpointing system can secure the blockchain even against an adversary with super-majority mining power. + +The Metronome checkpointing system consists of a generic BFT Service (preferably HotStuff), a Checkpoint-assisted Blockchain, and a Checkpointing Interpreter that bridges the two. This structure enables many features, including flexible BFT choices, multi-chain support, plug-and-play forensic monitoring platform via the BFT service, and the capability of bridging trust between two different blockchains. + +### Architecture + +BFT Service: A committee-based BFT service with a simple and generic interface. It takes consensus candidates (e.g., checkpoint candidates) as input and generates certificates for the elected ones. + +Checkpoint-assisted Blockchain: Maintains the main blockchain that accepts and applies checkpointing results. The checkpointing logic is delegated to the checkpointing interpreter below. + +Checkpointing Interpreter: Maintains checkpointing logic, including the creation and validation (via blockchain) of checkpointing candidates, as well as checkpoint-related validation of new blockchain blocks. + +Each of these modules can be developed independently with only minor data structure changes required for compatibility. This independence allows flexibility with the choice of BFT algorithm (e.g., variants of OBFT or Hotstuff) and checkpointing interpreter (e.g., simple checkpoints or Advocate). + +The architecture also enables a convenient forensic monitoring module. By simply connecting to the BFT service, the forensics module can download the stream of consensus data and detect illegal behaviors such as collusion, and identify the offenders. + +![Architecture diagram](docs/architecture.png) + +![Component diagram](docs/components.png) + +### BFT Algorithm + +The BFT service delegates checkpoint proposal and candidate validation to the Checkpointing Interpreter using 2-way communication to allow asynchronous responses as and when the data becomes available. + +![Algorithm diagram](docs/master-based.png) + +When a winner is elected, a Checkpoint Certificate is compiled, comprising the checkpointed data (a block identity, or something more complex) and a witness for the BFT agreement, which proves that the decision is final and cannot be rolled back. Because of the need for this proof, low latency BFT algorithms such as HotStuff are preferred. + + +## Build + +The project is built using [Mill](https://github.com/com-lihaoyi/mill), which works fine with [Metals](https://scalameta.org/metals/docs/build-tools/mill.html). + +To compile everything, use the `__` wildcard: + +```console +mill __.compile +``` + +The project is set up to cross build to all Scala versions for downstream projects that need to import the libraries. 
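For reference, the Scala versions behind the bracketed selectors come from the `Cross` declaration in `build.sc`, added later in this same change; a minimal excerpt:

```scala
// build.sc (excerpt): the cross-built Scala versions the wildcard selectors expand to.
object metronome extends Cross[MetronomeModule]("2.12.13", "2.13.4")
```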
To build any specific version, put it in square brackets: + +```console +mill metronome[2.12.13].checkpointing.app.compile +``` + +To run tests, use the wildcards again and the `.test` postfix: + +```console +mill __.test +mill --watch metronome[2.13.4].rocksdb.test +``` + +To run a single test class, use the `.single` method with the full path to the spec: + +```console +mill __.storage.test.single io.iohk.metronome.storage.KVStoreStateSpec +``` + +To experiment with the code, start an interactive session: + +```console +mill -i metronome[2.13.4].hotstuff.consensus.console +``` + +### Formatting the codebase + +Please configure your editor to use `scalafmt` on save. CI is configured to check formatting. + + +## Publishing + +We're using the [VersionFile](https://com-lihaoyi.github.io/mill/page/contrib-modules.html#version-file) plugin to manage versions. + +The initial version has been written to the file without newlines: +```console +echo -n "0.1.0-SNAPSHOT" > versionFile/version +``` + +Builds on `develop` will publish the snapshot version to Sonatype, which can be overwritten if the version number is not updated. + +During [publishing](https://com-lihaoyi.github.io/mill/page/common-project-layouts.html#publishing) on `master` we will use `mill versionFile.setReleaseVersion` to remove the `-SNAPSHOT` postfix and make a release. After that the version number should be bumped on `develop`, e.g. `mill versionFile.setNextVersion --bump minor`. diff --git a/build.sc b/build.sc new file mode 100644 index 00000000..f478b64c --- /dev/null +++ b/build.sc @@ -0,0 +1,413 @@ +import mill._ +import mill.modules._ +import scalalib._ +import ammonite.ops._ +import coursier.maven.MavenRepository +import mill.scalalib.{PublishModule, ScalaModule} +import mill.scalalib.publish.{Developer, License, PomSettings, VersionControl} +import $ivy.`com.lihaoyi::mill-contrib-versionfile:$MILL_VERSION` +import mill.contrib.versionfile.VersionFileModule + +object versionFile extends VersionFileModule + +object VersionOf { + val `better-monadic-for` = "0.3.1" + val cats = "2.3.1" + val circe = "0.12.3" + val config = "1.4.1" + val `kind-projector` = "0.11.3" + val logback = "1.2.3" + val mantis = "3.2.1-SNAPSHOT" + val monix = "3.3.0" + val prometheus = "0.10.0" + val rocksdb = "6.15.2" + val scalacheck = "1.15.2" + val scalatest = "3.2.5" + val scalanet = "0.8.0" + val shapeless = "2.3.3" + val slf4j = "1.7.30" + val `scodec-core` = "1.11.7" + val `scodec-bits` = "1.1.12" +} + +// Using 2.12.13 instead of 2.12.10 to access @nowarn, to disable certain deprecation +// warnings that come up in 2.13 but are too awkward to work around. +object metronome extends Cross[MetronomeModule]("2.12.13", "2.13.4") + +class MetronomeModule(val crossScalaVersion: String) extends CrossScalaModule { + + // Get rid of the `metronome-2.13.4-` part from the artifact name. The JAR name suffix will show the Scala version. + // Check with `mill show metronome[2.13.4].__.artifactName` or `mill __.publishLocal`. + private def removeCrossVersion(artifactName: String): String = + "metronome-" + artifactName.split("-").drop(2).mkString("-") + + // In objects inheriting this trait, use `override def moduleDeps: Seq[PublishModule]` + // to point at other modules that also get published. In other cases such as tests + // it can be `override def moduleDeps: Seq[JavaModule]`, i.e. point at any module. + trait Publishing extends PublishModule { + def description: String + + // Make sure there's no newline in the file.
+ override def publishVersion = versionFile.currentVersion().toString + + override def pomSettings = PomSettings( + description = description, + organization = "io.iohk", + url = "https://github.com/input-output-hk/metronome", + licenses = Seq(License.`Apache-2.0`), + versionControl = VersionControl.github("input-output-hk", "metronome"), + // Add yourself if you make a PR! + // format: off + developers = Seq( + Developer("aakoshh", "Akosh Farkash", "https://github.com/aakoshh"), + Developer("lemastero","Piotr Paradzinski","https://github.com/lemastero"), + Developer("KonradStaniec","Konrad Staniec","https://github.com/KonradStaniec"), + Developer("rtkaczyk", "Radek Tkaczyk", "https://github.com/rtkaczyk"), + Developer("biandratti", "Maxi Biandratti", "https://github.com/biandratti") + ) + // format: on + ) + } + + /** Common properties for all Scala modules. */ + trait SubModule extends ScalaModule { + override def scalaVersion = crossScalaVersion + override def artifactName = removeCrossVersion(super.artifactName()) + + override def ivyDeps = Agg( + ivy"org.typelevel::cats-core:${VersionOf.cats}", + ivy"org.typelevel::cats-effect:${VersionOf.cats}" + ) + + override def scalacPluginIvyDeps = Agg( + ivy"com.olegpy::better-monadic-for:${VersionOf.`better-monadic-for`}" + ) + + override def repositories = super.repositories ++ Seq( + MavenRepository("https://oss.sonatype.org/content/repositories/snapshots") + ) + + override def scalacOptions = Seq( + "-unchecked", + "-deprecation", + "-feature", + "-encoding", + "utf-8", + "-Xfatal-warnings", + "-Ywarn-value-discard" + ) ++ { + crossScalaVersion.take(4) match { + case "2.12" => + // These options don't work well with 2.13 + Seq( + "-Xlint:unsound-match", + "-Ywarn-inaccessible", + "-Ywarn-unused-import", + "-Ywarn-unused:locals", + "-Ywarn-unused:patvars", + "-Ypartial-unification", // Required for the `>>` syntax. + "-language:higherKinds", + "-language:postfixOps" + ) + case "2.13" => + Seq() + } + } + + // `extends Tests` uses the context of the module in which it's defined + trait TestModule extends Tests { + override def artifactName = + removeCrossVersion(super.artifactName()) + + override def scalacOptions = + SubModule.this.scalacOptions + + override def testFrameworks = + Seq( + "org.scalatest.tools.Framework", + "org.scalacheck.ScalaCheckFramework" + ) + + // It may be useful to see logs in tests. + override def moduleDeps: Seq[JavaModule] = + super.moduleDeps ++ Seq(logging) + + // Enable logging in tests. + // Control the visibility using ./test/resources/logback.xml + // Alternatively, capture logs in memory. + override def ivyDeps = Agg( + ivy"org.scalatest::scalatest:${VersionOf.scalatest}", + ivy"org.scalacheck::scalacheck:${VersionOf.scalacheck}", + ivy"ch.qos.logback:logback-classic:${VersionOf.logback}" + ) + + def single(args: String*) = T.command { + // ScalaCheck test + if (args.headOption.exists(_.endsWith("Props"))) + super.runMain(args.head, args.tail: _*) + // ScalaTest test + else + super.runMain("org.scalatest.run", args: _*) + } + } + } + + /** Abstractions shared between all modules. */ + object core extends SubModule with Publishing { + override def description: String = + "Common abstractions." + + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"com.chuusai::shapeless:${VersionOf.shapeless}", + ivy"io.monix::monix:${VersionOf.monix}" + ) + + object test extends TestModule + } + + /** Storage abstractions, e.g. a generic key-value store. 
*/ + object storage extends SubModule { + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"org.typelevel::cats-free:${VersionOf.cats}", + ivy"org.scodec::scodec-bits:${VersionOf.`scodec-bits`}", + ivy"org.scodec::scodec-core:${VersionOf.`scodec-core`}" + ) + + object test extends TestModule + } + + /** Emit trace events, abstracting away logs and metrics. + * + * Based on https://github.com/input-output-hk/iohk-monitoring-framework/tree/master/contra-tracer + */ + object tracing extends SubModule with Publishing { + override def description: String = + "Abstractions for contravariant tracing." + + def scalacPluginIvyDeps = Agg( + ivy"org.typelevel:::kind-projector:${VersionOf.`kind-projector`}" + ) + } + + /** Additional crypto utilities such as threshold signatures. */ + object crypto extends SubModule with Publishing { + override def description: String = + "Cryptographic primitives to support HotStuff and BFT proof verification." + + override def moduleDeps: Seq[PublishModule] = + Seq(core) + + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"io.iohk::mantis-crypto:${VersionOf.mantis}", + ivy"org.scodec::scodec-bits:${VersionOf.`scodec-bits`}", + ivy"org.scodec::scodec-core:${VersionOf.`scodec-core`}" + ) + + object test extends TestModule + } + + /** Generic Peer-to-Peer components that can multiplex protocols + * from different modules over a single authenticated TLS connection. + */ + object networking extends SubModule { + override def moduleDeps: Seq[JavaModule] = + Seq(tracing, crypto) + + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"io.iohk::scalanet:${VersionOf.scalanet}" + ) + + object test extends TestModule { + override def moduleDeps: Seq[JavaModule] = + super.moduleDeps ++ Seq(logging) + } + } + + /** General configuration parser, to be used by application modules. */ + object config extends SubModule with Publishing { + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"com.typesafe:config:${VersionOf.config}", + ivy"io.circe::circe-core:${VersionOf.circe}", + ivy"io.circe::circe-parser:${VersionOf.circe}" + ) + + object test extends TestModule { + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"io.circe::circe-generic:${VersionOf.circe}" + ) + } + + override def description = "Typesafe config wrapper powered by circe" + } + + /** Generic HotStuff BFT library. */ + object hotstuff extends SubModule { + + /** Pure consensus models. */ + object consensus extends SubModule with Publishing { + override def description: String = + "Pure HotStuff consensus models." + + override def moduleDeps: Seq[PublishModule] = + Seq(core, crypto) + + object test extends TestModule + } + + /** Expose forensics events via tracing. */ + object forensics extends SubModule + + /** Implements peer-to-peer communication, state and block synchronisation. + * + * Includes the remote communication protocol messages and networking. + */ + object service extends SubModule { + override def moduleDeps: Seq[JavaModule] = + Seq( + storage, + tracing, + crypto, + networking, + hotstuff.consensus, + hotstuff.forensics + ) + + object test extends TestModule { + override def moduleDeps: Seq[JavaModule] = + super.moduleDeps ++ Seq(hotstuff.consensus.test) + } + } + } + + /** Components realising the checkpointing functionality using HotStuff. */ + object checkpointing extends SubModule { + + /** Library to be included on the PoW side to validate checkpoint certificates. + * + * Includes the certificate model and the checkpoint ledger and chain models.
+ */ + object models extends SubModule with Publishing { + override def description: String = + "Checkpointing domain models, including the checkpoint certificate and its validation logic." + + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"io.iohk::mantis-rlp:${VersionOf.mantis}" + ) + + override def moduleDeps: Seq[PublishModule] = + Seq(core, crypto, hotstuff.consensus) + + object test extends TestModule { + override def moduleDeps: Seq[JavaModule] = + super.moduleDeps ++ Seq(hotstuff.consensus.test) + } + } + + /** Library to be included on the PoW side to talk to the checkpointing service. + * + * Includes the local communication protocol messages and networking. + */ + object interpreter extends SubModule with Publishing { + override def description: String = + "Components to implement a PoW side checkpointing interpreter." + + override def ivyDeps = Agg( + ivy"io.iohk::scalanet:${VersionOf.scalanet}" + ) + + override def moduleDeps: Seq[PublishModule] = + Seq(tracing, crypto, checkpointing.models) + } + + /** Implements the checkpointing functionality, validation rules, + * state synchronisation, anything that is not an inherent part of + * HotStuff, but applies to the checkpointing use case. + * + * If it was published, it could be directly included in the checkpoint + * assisted blockchain application, so the service and the interpreter + * can share data in memory. + */ + object service extends SubModule { + override def moduleDeps: Seq[JavaModule] = + Seq( + tracing, + storage, + hotstuff.service, + checkpointing.models, + checkpointing.interpreter + ) + + object test extends TestModule { + override def moduleDeps: Seq[JavaModule] = + super.moduleDeps ++ Seq( + checkpointing.models.test, + hotstuff.service.test + ) + } + } + + /** Executable application for running HotStuff and checkpointing as a stand-alone process, + * communicating with the interpreter over TCP. + */ + object app extends SubModule { + override def moduleDeps: Seq[JavaModule] = + Seq( + hotstuff.service, + checkpointing.service, + rocksdb, + logging, + metrics, + config + ) + + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"ch.qos.logback:logback-classic:${VersionOf.logback}", + ivy"io.iohk::scalanet-discovery:${VersionOf.scalanet}" + ) + + object test extends TestModule + } + } + + /** Implements tracing abstractions to do structured logging. + * + * To actually emit logs, a dependant module also has to add + * a dependency on e.g. logback. + */ + object logging extends SubModule { + override def moduleDeps: Seq[JavaModule] = + Seq(tracing) + + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"org.slf4j:slf4j-api:${VersionOf.slf4j}", + ivy"io.circe::circe-core:${VersionOf.circe}" + ) + } + + /** Implements tracing abstractions to expose metrics to Prometheus. */ + object metrics extends SubModule { + override def moduleDeps: Seq[JavaModule] = + Seq(tracing) + + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"io.prometheus:simpleclient:${VersionOf.prometheus}", + ivy"io.prometheus:simpleclient_httpserver:${VersionOf.prometheus}" + ) + } + + /** Implements the storage abstractions using RocksDB. 
*/ + object rocksdb extends SubModule { + override def moduleDeps: Seq[JavaModule] = + Seq(storage) + + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"org.rocksdb:rocksdbjni:${VersionOf.rocksdb}" + ) + + object test extends TestModule { + override def ivyDeps = super.ivyDeps() ++ Agg( + ivy"io.monix::monix:${VersionOf.monix}" + ) + } + } +} diff --git a/docs/architecture.png b/docs/architecture.png new file mode 100644 index 00000000..d7317166 Binary files /dev/null and b/docs/architecture.png differ diff --git a/docs/components.png b/docs/components.png new file mode 100644 index 00000000..17df5ed8 Binary files /dev/null and b/docs/components.png differ diff --git a/docs/master-based.png b/docs/master-based.png new file mode 100644 index 00000000..d75c7583 Binary files /dev/null and b/docs/master-based.png differ diff --git a/metronome/checkpointing/interpreter/src/io/iohk/metronome/checkpointing/interpreter/messages/InterpreterMessage.scala b/metronome/checkpointing/interpreter/src/io/iohk/metronome/checkpointing/interpreter/messages/InterpreterMessage.scala new file mode 100644 index 00000000..8226fa93 --- /dev/null +++ b/metronome/checkpointing/interpreter/src/io/iohk/metronome/checkpointing/interpreter/messages/InterpreterMessage.scala @@ -0,0 +1,175 @@ +package io.iohk.metronome.checkpointing.interpreter.messages + +import io.iohk.metronome.core.messages.{RPCMessage, RPCMessageCompanion} +import io.iohk.metronome.checkpointing.models.{ + Transaction, + Ledger, + Block, + CheckpointCertificate +} + +/** Messages exchanged between the Checkpointing Service + * and the local Checkpointing Interpreter. + */ +sealed trait InterpreterMessage { self: RPCMessage => } + +object InterpreterMessage extends RPCMessageCompanion { + + /** Messages from the Service to the Interpreter. */ + sealed trait FromService + + /** Messages from the Interpreter to the Service. */ + sealed trait FromInterpreter + + /** Mark requests that require no response. */ + sealed trait NoResponse { self: Request => } + + /** The Interpreter notifies the Service about a new + * proposer block that should be added to the mempool. + * + * Only used in Advocate. + */ + case class NewProposerBlockRequest( + requestId: RequestId, + proposerBlock: Transaction.ProposerBlock + ) extends InterpreterMessage + with Request + with FromInterpreter + with NoResponse + + /** The Interpreter signals to the Service that it can + * potentially produce a new checkpoint candidate in + * the next view when the replica becomes leader. + * + * In that round, the Service should send a `CreateBlockBodyRequest`. + * + * This is a potential optimization, so we don't send the `Ledger` + * in futile attempts when there's no chance for a block to + * be produced when there have been no events. + */ + case class NewCheckpointCandidateRequest( + requestId: RequestId + ) extends InterpreterMessage + with Request + with FromInterpreter + with NoResponse + + /** When it becomes a leader of a view, the Service asks + * the Interpreter to produce a new block body, populating + * it with transactions in the correct order, based on + * the current ledger and the mempool. + * + * A response is expected even when there are no transactions + * to be put in a block, so that we can move on to the next + * leader after an idle round (agreeing on an empty block), + * without incurring a full timeout. 
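As a purely illustrative sketch (not part of this PR) of the request/response shape defined below: a PoW-side handler could assemble the block body and echo the request id roughly as follows. Taking the candidate as a parameter and the naive ordering are simplifications of what a real interpreter does.

```scala
import io.iohk.metronome.checkpointing.interpreter.messages.InterpreterMessage._
import io.iohk.metronome.checkpointing.models.{Block, Transaction}

// Hypothetical interpreter-side handler: the caller has already decided whether
// a checkpoint candidate can be proposed for the ledger in the request.
def handleCreateBlockBody(
    req: CreateBlockBodyRequest,
    candidate: Option[Transaction.CheckpointCandidate]
): CreateBlockBodyResponse = {
  // Naive ordering: all mempool proposer blocks first, then the candidate (if any).
  // The real interpreter decides which proposer blocks may precede the checkpoint.
  val transactions: IndexedSeq[Transaction] =
    req.mempool.toIndexedSeq ++ candidate
  CreateBlockBodyResponse(req.requestId, Block.Body(transactions))
}
```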
+ * + * The reason the mempool has to be sent to the Interpreter + * and not just appended to the block, with a potential + * checkpoint at the end, is that the checkpoint empties + * the Ledger, and the Service has no way of knowing whether + * all proposer blocks have been rightly checkpointed. The + * Interpreter, on the other hand, can put the checkpoint + * in the correct position in the block body, and make sure + * that proposer blocks which cannot be checkpointed yet are + * added in a trailing position. + * + * The mempool will eventually be cleared by the Service as + * blocks are executed, based on what transactions they have. + * + * Another reason the ledger and mempool are sent and not + * handled inside the Interpreter alone is that the Service + * can project the correct values based on what (potentially + * uncommitted) parent block it's currently trying to extend, + * by updating the last stable ledger and filtering the mempool + * based on the blocks in the tentative branch. The Interpreter + * doesn't have access to the block history, so it couldn't do + * the same on its own. + */ + case class CreateBlockBodyRequest( + requestId: RequestId, + ledger: Ledger, + mempool: Seq[Transaction.ProposerBlock] + ) extends InterpreterMessage + with Request + with FromService + + /** The Interpreter may or may not be able to produce a new + * checkpoint candidate, depending on whether the conditions + * are right (e.g. the next checkpointing height has been reached). + * + * The response should contain an empty block body if there is + * nothing to do, so the Service can either propose an empty block + * to keep everyone in sync, or just move to the next leader by + * other means. + */ + case class CreateBlockBodyResponse( + requestId: RequestId, + blockBody: Block.Body + ) extends InterpreterMessage + with Response + with FromInterpreter + + /** The Service asks the Interpreter to validate all transactions + * in a block, given the current ledger state. + * + * This could be done transaction by transaction, but that would + * require sending the ledger every step along the way, which + * would be less efficient. + * + * If the Interpreter doesn't have enough data to validate the + * block, it should hold on to it until it does, only responding + * when it has the final conclusion. + * + * If the transactions are valid, the Service will apply them + * on the ledger on its own; the update rules are transparent. + */ + case class ValidateBlockBodyRequest( + requestId: RequestId, + blockBody: Block.Body, + ledger: Ledger + ) extends InterpreterMessage + with Request + with FromService + + /** The Interpreter responds to the block validation request when + * it has all the data available to perform the validation. + * + * The result indicates whether the block contents were valid. + * + * Reasons for being invalid could be that a checkpoint + * was proposed which is inconsistent with the current ledger, + * or that a proposer block was pointing at an invalid block. + * + * If valid, the Service updates its copy of the ledger + * and checks that the `postStateHash` in the block also + * corresponds to its state. + */ + case class ValidateBlockBodyResponse( + requestId: RequestId, + isValid: Boolean + ) extends InterpreterMessage + with Response + with FromInterpreter + + /** The Service notifies the Interpreter about a new Checkpoint Certificate + * having been constructed, after a block had been committed that resulted + * in the commit of a checkpoint candidate.
* + * The certificate is created by the Service because it has access to all the + * block headers and quorum certificates, and thus can construct the Merkle proof. + */ + case class NewCheckpointCertificateRequest( + requestId: RequestId, + checkpointCertificate: CheckpointCertificate + ) extends InterpreterMessage + with Request + with FromService + with NoResponse + + implicit val createBlockBodyPair = + pair[CreateBlockBodyRequest, CreateBlockBodyResponse] + + implicit val validateBlockBodyPair = + pair[ValidateBlockBodyRequest, ValidateBlockBodyResponse] +} diff --git a/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/CheckpointingAgreement.scala b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/CheckpointingAgreement.scala new file mode 100644 index 00000000..944a8b47 --- /dev/null +++ b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/CheckpointingAgreement.scala @@ -0,0 +1,43 @@ +package io.iohk.metronome.checkpointing + +import io.iohk.metronome.crypto +import io.iohk.metronome.hotstuff.consensus +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.hotstuff.consensus.basic.{ + Secp256k1Agreement, + Signing, + VotingPhase +} +import scodec.bits.ByteVector +import io.iohk.ethereum.rlp +import io.iohk.metronome.checkpointing.models.RLPCodecs._ + +object CheckpointingAgreement extends Secp256k1Agreement { + override type Block = models.Block + override type Hash = models.Block.Header.Hash + + type GroupSignature = crypto.GroupSignature[ + PKey, + (VotingPhase, ViewNumber, Hash), + GSig + ] + + implicit val signing: Signing[CheckpointingAgreement] = + Signing.secp256k1((phase, viewNumber, hash) => + ByteVector( + rlp.encode(phase) ++ rlp.encode(viewNumber) ++ rlp.encode(hash) + ) + ) + + implicit val block: consensus.basic.Block[CheckpointingAgreement] = + new consensus.basic.Block[CheckpointingAgreement] { + override def blockHash(b: models.Block) = + b.hash + override def parentBlockHash(b: models.Block) = + b.header.parentHash + override def height(b: Block): Long = + b.header.height + override def isValid(b: models.Block) = + models.Block.isValid(b) + } +} diff --git a/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/Block.scala b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/Block.scala new file mode 100644 index 00000000..45614769 --- /dev/null +++ b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/Block.scala @@ -0,0 +1,88 @@ +package io.iohk.metronome.checkpointing.models + +import scodec.bits.ByteVector + +/** Represents what the HotStuff paper called "nodes" in the "tree", + * with the transactions in the body being the "commands". + * + * The block contents are specific to the checkpointing application. + * + * The header and body are separated because headers have to be part + * of the Checkpoint Certificate; there's no need to repeat all + * the transactions there; the Merkle root will make it possible + * to prove that a given CheckpointCandidate transaction was + * indeed part of the block. The headers are needed for parent-child + * validation in the certificate as well. + */ +sealed abstract case class Block private ( + header: Block.Header, + body: Block.Body +) { + def hash: Block.Header.Hash = header.hash +} + +object Block { + type Hash = Block.Header.Hash + + /** Create a block from a header and body we received from the network.
* + * It will need to be validated before it can be used, to make sure + * the header really belongs to the body. + */ + def makeUnsafe(header: Header, body: Body): Block = + new Block(header, body) {} + + /** Smart constructor for a block, setting the correct hashes in the header. */ + def make( + parent: Block, + postStateHash: Ledger.Hash, + transactions: IndexedSeq[Transaction] + ): Block = { + val body = Body(transactions) + val header = Header( + parentHash = parent.hash, + height = parent.header.height + 1, + postStateHash = postStateHash, + contentMerkleRoot = Body.contentMerkleRoot(body) + ) + makeUnsafe(header, body) + } + + /** Check that the block hashes are valid. */ + def isValid(block: Block): Boolean = + block.header.contentMerkleRoot == Body.contentMerkleRoot(block.body) + + /** The first, empty block. */ + val genesis: Block = { + val body = Body(Vector.empty) + val header = Header( + parentHash = Block.Header.Hash(ByteVector.empty), + height = 0, + postStateHash = Ledger.empty.hash, + contentMerkleRoot = MerkleTree.empty.hash + ) + makeUnsafe(header, body) + } + + case class Header( + parentHash: Header.Hash, + height: Long, + // Hash of the Ledger after executing the block. + postStateHash: Ledger.Hash, + // Merkle root of the transactions in the body. + contentMerkleRoot: MerkleTree.Hash + ) extends RLPHash[Header, Header.Hash] + + object Header extends RLPHashCompanion[Header]()(RLPCodecs.rlpBlockHeader) + + case class Body( + transactions: IndexedSeq[Transaction] + ) extends RLPHash[Body, Body.Hash] + + object Body extends RLPHashCompanion[Body]()(RLPCodecs.rlpBlockBody) { + def contentMerkleRoot(body: Body): MerkleTree.Hash = + MerkleTree + .build(body.transactions.map(tx => MerkleTree.Hash(tx.hash))) + .hash + } +} diff --git a/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/CheckpointCertificate.scala b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/CheckpointCertificate.scala new file mode 100644 index 00000000..dc0167cd --- /dev/null +++ b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/CheckpointCertificate.scala @@ -0,0 +1,54 @@ +package io.iohk.metronome.checkpointing.models + +import cats.data.NonEmptyList +import io.iohk.metronome.hotstuff.consensus.basic.QuorumCertificate +import io.iohk.metronome.checkpointing.CheckpointingAgreement +import io.iohk.metronome.checkpointing.models.Transaction.CheckpointCandidate + +/** The Checkpoint Certificate is a proof of the BFT agreement + * over a given Checkpoint Candidate. + * + * It contains the group signature over the block that the + * federation committed, together with the sequence of blocks + * from the one that originally introduced the Candidate. + * + * The interpreter can follow the parent-child relationships, + * validate the hashes and the inclusion of the Candidate in + * the original block, check the group signature, then unpack + * the contents of the Candidate to interpret it according to + * whatever rules apply on the checkpointed PoW chain. + */ +case class CheckpointCertificate( + // `head` is the `Block.Header` that had the `CheckpointCandidate` in its `Body`. + // `last` is the `Block.Header` that has the Commit Q.C.; + headers: NonEmptyList[Block.Header], + // The opaque contents of the checkpoint that has been agreed upon. + checkpoint: Transaction.CheckpointCandidate, + // Proof that `checkpoint` is part of `headers.head.contentMerkleRoot`. + proof: MerkleTree.Proof, + // Commit Q.C. over `headers.last`.
+ commitQC: QuorumCertificate[CheckpointingAgreement] +) + +object CheckpointCertificate { + def construct( + block: Block, + headers: NonEmptyList[Block.Header], + commitQC: QuorumCertificate[CheckpointingAgreement] + ): Option[CheckpointCertificate] = + constructProof(block).map { case (proof, cp) => + CheckpointCertificate(headers, cp, proof, commitQC) + } + + private def constructProof( + block: Block + ): Option[(MerkleTree.Proof, CheckpointCandidate)] = + block.body.transactions.reverseIterator.collectFirst { + case cp: CheckpointCandidate => + val txHashes = + block.body.transactions.map(tx => MerkleTree.Hash(tx.hash)) + val tree = MerkleTree.build(txHashes) + val cpHash = MerkleTree.Hash(cp.hash) + MerkleTree.generateProofFromHash(tree, cpHash).map(_ -> cp) + }.flatten +} diff --git a/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/Ledger.scala b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/Ledger.scala new file mode 100644 index 00000000..3c053010 --- /dev/null +++ b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/Ledger.scala @@ -0,0 +1,40 @@ +package io.iohk.metronome.checkpointing.models + +/** Current state of the ledger after applying all previous blocks. + * + * Basically it's the last checkpoint, plus any accumulated proposer blocks + * since then. Initially the last checkpoint is empty; conceptually it could + * be the genesis block of the PoW chain, but we don't know what that is + * until we talk to the interpreter, and we also can't produce it on our + * own since it's opaque data. + */ +case class Ledger( + maybeLastCheckpoint: Option[Transaction.CheckpointCandidate], + proposerBlocks: IndexedSeq[Transaction.ProposerBlock] +) extends RLPHash[Ledger, Ledger.Hash] { + + /** Apply a validated transaction to produce the next ledger state. + * + * The transaction should have been validated against the PoW ledger + * by this point, so we know for example that the new checkpoint is + * a valid extension of the previous one. 
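To make the `update` contract above concrete, here is a minimal sketch (not part of this PR) of the kind of bookkeeping the Service could do with it: fold a validated block body into the ledger and check the result against the header's `postStateHash`, as described for `ValidateBlockBodyResponse` earlier in this diff.

```scala
import io.iohk.metronome.checkpointing.models.{Block, Ledger}

// Apply a block's transactions with `Ledger.update` and compare the resulting
// ledger hash to the `postStateHash` committed in the header; `None` signals
// a post-state mismatch.
def applyBlock(ledger: Ledger, block: Block): Option[Ledger] = {
  val updated = ledger.update(block.body.transactions)
  if (updated.hash == block.header.postStateHash) Some(updated) else None
}
```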
* + */ + def update(transaction: Transaction): Ledger = + transaction match { + case t @ Transaction.ProposerBlock(_) => + if (proposerBlocks.contains(t)) + this + else + copy(proposerBlocks = proposerBlocks :+ t) + + case t @ Transaction.CheckpointCandidate(_) => + Ledger(Some(t), Vector.empty) + } + + def update(transactions: Iterable[Transaction]): Ledger = + transactions.foldLeft(this)(_ update _) +} + +object Ledger extends RLPHashCompanion[Ledger]()(RLPCodecs.rlpLedger) { + val empty = Ledger(None, Vector.empty) +} diff --git a/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/MerkleTree.scala b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/MerkleTree.scala new file mode 100644 index 00000000..624c8918 --- /dev/null +++ b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/MerkleTree.scala @@ -0,0 +1,212 @@ +package io.iohk.metronome.checkpointing.models + +import io.iohk.metronome.core.Tagger +import io.iohk.metronome.crypto.hash.Keccak256 +import scodec.bits.ByteVector + +import scala.annotation.tailrec + +sealed trait MerkleTree { + def hash: MerkleTree.Hash +} + +object MerkleTree { + object Hash extends Tagger[ByteVector] + type Hash = Hash.Tagged + + /** MerkleTree with no elements */ + val empty = Leaf(Hash(Keccak256(ByteVector.empty))) + + private val hashFn: (Hash, Hash) => Hash = + (a, b) => Hash(Keccak256(a ++ b)) + + case class Node(hash: Hash, left: MerkleTree, right: Option[MerkleTree]) + extends MerkleTree + case class Leaf(hash: Hash) extends MerkleTree + + /** Merkle proof that some leaf content is part of the tree. + * + * It is expected that the root hash and the leaf itself are available to + * the verifier, so the proof only contains things the verifier doesn't + * know, which are the overall size of the tree and the position of the leaf + * among its sibling leaves. Based on that it is possible to use the sibling + * hash path to check whether they add up to the root hash. + * + * `leafIndex` can be interpreted as a binary number, which represents + * the path from the root of the tree down to the leaf, with the bits + * indicating whether to go left or right in each fork, while descending + * the levels. + * + * For example, take the following Merkle tree: + * ``` + * h0123 + * / \ + * h01 h23 + * / \ / \ + * h0 h1 h2 h3 + * ``` + * + * Say we want to prove that leaf 2 is part of the tree. The binary + * representation of 2 is `10`, which we can interpret as: go right, + * then go left. + * + * The sibling path in the proof would be: `Vector(h3, h01)`. + * + * Based on this we can take the leaf value we know, reconstruct the hashes + * from the bottom to the top, and compare it against the root hash we know: + * + * ``` + * h2 = h(value) + * h23 = h(h2, path(0)) + * h0123 = h(path(1), h23) + * assert(h0123 == root) + * ``` + * + * The right/left decisions we gleaned from the `leafIndex` tell us the order + * in which we have to pass the arguments to the hash function. + * + * Note that the length of the binary representation of `leafIndex` corresponds + * to the height of the tree, e.g. `0010` for a tree of height 4 (9 to 16 leaves). + */ + case class Proof( + // Position of the leaf in the lowest level. + leafIndex: Int, + // Hashes of the "other" side of the tree, level by level, + // starting from the lowest up to the highest.
+ siblingPath: IndexedSeq[Hash] + ) + + def build(elems: Iterable[Hash]): MerkleTree = { + @tailrec + def buildTree(nodes: Seq[MerkleTree]): MerkleTree = { + if (nodes.size == 1) + nodes.head + else { + val paired = nodes.grouped(2).toSeq.map { + case Seq(a, b) => + Node(hashFn(a.hash, b.hash), a, Some(b)) + case Seq(a) => + // if the element has no pair we hash it with itself + Node(hashFn(a.hash, a.hash), a, None) + } + buildTree(paired) + } + } + + if (elems.isEmpty) + empty + else + buildTree(elems.toSeq.map(Leaf(_))) + } + + def verifyProof( + proof: Proof, + root: Hash, + leaf: Hash + ): Boolean = { + def verify(currentHash: Hash, height: Int, siblings: Seq[Hash]): Hash = { + if (siblings.isEmpty) + currentHash + else { + val goLeft = shouldTraverseLeft(height, proof.leafIndex) + val nextHash = + if (goLeft) hashFn(currentHash, siblings.head) + else hashFn(siblings.head, currentHash) + + verify(nextHash, height + 1, siblings.tail) + } + } + + verify(leaf, 1, proof.siblingPath) == root + } + + def generateProofFromIndex(root: MerkleTree, index: Int): Option[Proof] = { + if (index < 0 || index >= findSize(root)) + None + else { + val siblings = findSiblings(root, findHeight(root), index) + Some(Proof(index, siblings)) + } + } + + def generateProofFromHash(root: MerkleTree, elem: Hash): Option[Proof] = { + if (root == empty) + None + else + findElem(root, elem).map { index => + val siblings = findSiblings(root, findHeight(root), index) + Proof(index, siblings) + } + } + + @tailrec + /** Finds tree height based on leftmost branch traversal */ + private def findHeight(tree: MerkleTree, height: Int = 0): Int = tree match { + case Leaf(_) => height + case Node(_, left, _) => findHeight(left, height + 1) + } + + @tailrec + /** Finds the tree size (number of leaves), by traversing the rightmost branch */ + private def findSize(tree: MerkleTree, maxIndex: Int = 0): Int = tree match { + case `empty` => + 0 + case Leaf(_) => + maxIndex + 1 + case Node(_, left, None) => + findSize(left, maxIndex << 1) + case Node(_, _, Some(right)) => + findSize(right, maxIndex << 1 | 1) + } + + /** Looks up an element hash in the tree returning its index if it exists */ + private def findElem( + tree: MerkleTree, + elem: Hash, + index: Int = 0 + ): Option[Int] = tree match { + case Leaf(`elem`) => + Some(index) + case Leaf(_) => + None + case Node(_, left, None) => + findElem(left, elem, index << 1) + case Node(_, left, Some(right)) => + findElem(left, elem, index << 1) orElse + findElem(right, elem, index << 1 | 1) + } + + /** Traverses the tree from root towards the leaf collecting the hashes of siblings nodes. + * If a node has only one child then that child's hash is collected. 
These hashes constitute + * the Merkle proof; they are returned ordered from lowest to highest (with regard to + * the height of the tree). + */ + private def findSiblings( + tree: MerkleTree, + height: Int, + leafIndex: Int + ): IndexedSeq[Hash] = tree match { + case Leaf(_) => + Vector.empty + + case Node(_, left, None) => + if (!shouldTraverseLeft(height, leafIndex)) + Vector.empty + else + findSiblings(left, height - 1, leafIndex) :+ left.hash + + case Node(_, left, Some(right)) => + val goLeft = shouldTraverseLeft(height, leafIndex) + val (traverse, sibling) = if (goLeft) (left, right) else (right, left) + findSiblings(traverse, height - 1, leafIndex) :+ sibling.hash + } + + /** Determines tree traversal direction from a given height towards the leaf indicated + * by the index: + * + * true - traverse left child (take right hash) + * false - traverse right + */ + private def shouldTraverseLeft(height: Int, leafIndex: Int): Boolean = + (leafIndex >> (height - 1) & 1) == 0 +} diff --git a/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/RLPCodecs.scala b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/RLPCodecs.scala new file mode 100644 index 00000000..98d83cb5 --- /dev/null +++ b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/RLPCodecs.scala @@ -0,0 +1,200 @@ +package io.iohk.metronome.checkpointing.models + +import cats.data.NonEmptyList +import io.iohk.ethereum.crypto.ECDSASignature +import io.iohk.ethereum.rlp.{RLPEncoder, RLPList} +import io.iohk.ethereum.rlp.RLPCodec +import io.iohk.ethereum.rlp.RLPCodec.Ops +import io.iohk.ethereum.rlp.RLPException +import io.iohk.ethereum.rlp.RLPImplicitDerivations._ +import io.iohk.ethereum.rlp.RLPImplicits._ +import io.iohk.metronome.checkpointing.CheckpointingAgreement +import io.iohk.metronome.crypto.hash.Hash +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.hotstuff.consensus.basic.{ + Phase, + VotingPhase, + QuorumCertificate +} +import scodec.bits.{BitVector, ByteVector} + +object RLPCodecs { + implicit val rlpBitVector: RLPCodec[BitVector] = + implicitly[RLPCodec[Array[Byte]]].xmap(BitVector(_), _.toByteArray) + + implicit val rlpByteVector: RLPCodec[ByteVector] = + implicitly[RLPCodec[Array[Byte]]].xmap(ByteVector(_), _.toArray) + + implicit val hashRLPCodec: RLPCodec[Hash] = + implicitly[RLPCodec[ByteVector]].xmap(Hash(_), identity) + + implicit val headerHashRLPCodec: RLPCodec[Block.Header.Hash] = + implicitly[RLPCodec[ByteVector]].xmap(Block.Header.Hash(_), identity) + + implicit val bodyHashRLPCodec: RLPCodec[Block.Body.Hash] = + implicitly[RLPCodec[ByteVector]].xmap(Block.Body.Hash(_), identity) + + implicit val ledgerHashRLPCodec: RLPCodec[Ledger.Hash] = + implicitly[RLPCodec[ByteVector]].xmap(Ledger.Hash(_), identity) + + implicit val merkleHashRLPCodec: RLPCodec[MerkleTree.Hash] = + implicitly[RLPCodec[ByteVector]].xmap(MerkleTree.Hash(_), identity) + + implicit val rlpProposerBlock: RLPCodec[Transaction.ProposerBlock] = + deriveLabelledGenericRLPCodec + + implicit val rlpCheckpointCandidate + : RLPCodec[Transaction.CheckpointCandidate] = + deriveLabelledGenericRLPCodec + + implicit def rlpIndexedSeq[T: RLPCodec]: RLPCodec[IndexedSeq[T]] = + seqEncDec[T]().xmap(_.toVector, _.toSeq) + + implicit def rlpNonEmptyList[T: RLPCodec]: RLPCodec[NonEmptyList[T]] = + seqEncDec[T]().xmap( + xs => + NonEmptyList.fromList(xs.toList).getOrElse { + RLPException.decodeError("NonEmptyList", "List cannot be empty.") + }, + _.toList + )
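Stepping back to the `MerkleTree` API defined above, a small illustrative usage (not part of this PR) that ties together `build`, `generateProofFromIndex` and `verifyProof`, mirroring the `h0..h3` example in the `Proof` documentation:

```scala
import io.iohk.metronome.checkpointing.models.MerkleTree
import io.iohk.metronome.crypto.hash.Keccak256
import scodec.bits.ByteVector

// Four leaf hashes, standing in for h0..h3.
val leaves: Vector[MerkleTree.Hash] =
  Vector("h0", "h1", "h2", "h3").map { s =>
    MerkleTree.Hash(Keccak256(ByteVector(s.getBytes)))
  }

val tree = MerkleTree.build(leaves)

// Prove that leaf 2 (binary `10`: go right, then left) is part of the tree.
val proof = MerkleTree.generateProofFromIndex(tree, 2)

assert(proof.exists(p => MerkleTree.verifyProof(p, tree.hash, leaves(2))))
```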
+ + implicit val rlpLedger: RLPCodec[Ledger] = + deriveLabelledGenericRLPCodec + + implicit val rlpTransaction: RLPCodec[Transaction] = { + import Transaction._ + + val ProposerBlockTag: Short = 1 + val CheckpointCandidateTag: Short = 2 + + def encodeWithTag[T: RLPEncoder](tag: Short, value: T) = { + val t = RLPEncoder.encode(tag) + val l = RLPEncoder.encode(value).asInstanceOf[RLPList] + t +: l + } + + RLPCodec.instance[Transaction]( + { + case tx: ProposerBlock => + encodeWithTag(ProposerBlockTag, tx) + case tx: CheckpointCandidate => + encodeWithTag(CheckpointCandidateTag, tx) + }, + { case RLPList(tag, items @ _*) => + val rest = RLPList(items: _*) + tag.decodeAs[Short]("tag") match { + case ProposerBlockTag => + rest.decodeAs[ProposerBlock]("transaction") + case CheckpointCandidateTag => + rest.decodeAs[CheckpointCandidate]("transaction") + case unknown => + RLPException.decodeError( + "Transaction", + s"Unknown tag: $unknown", + List(tag) + ) + } + } + ) + } + + implicit val rlpBlockBody: RLPCodec[Block.Body] = + deriveLabelledGenericRLPCodec + + implicit val rlpBlockHeader: RLPCodec[Block.Header] = + deriveLabelledGenericRLPCodec + + // Cannot use derivation because Block is a sealed abstract case class, + // so it doesn't allow creation of an invalid block. + implicit val rlpBlock: RLPCodec[Block] = + RLPCodec.instance[Block]( + block => + RLPList( + RLPEncoder.encode(block.header), + RLPEncoder.encode(block.body) + ), + { case RLPList(header, body) => + val h = header.decodeAs[Block.Header]("header") + val b = body.decodeAs[Block.Body]("body") + Block.makeUnsafe(h, b) + } + ) + + implicit val rlpMerkleProof: RLPCodec[MerkleTree.Proof] = + deriveLabelledGenericRLPCodec + + implicit val rlpViewNumber: RLPCodec[ViewNumber] = + implicitly[RLPCodec[Long]].xmap(ViewNumber(_), identity) + + implicit val rlpVotingPhase: RLPCodec[VotingPhase] = + RLPCodec.instance[VotingPhase]( + phase => { + val tag: Short = phase match { + case Phase.Prepare => 1 + case Phase.PreCommit => 2 + case Phase.Commit => 3 + } + RLPEncoder.encode(tag) + }, + { case tag => + tag.decodeAs[Short]("phase") match { + case 1 => Phase.Prepare + case 2 => Phase.PreCommit + case 3 => Phase.Commit + case u => + RLPException.decodeError( + "VotingPhase", + s"Unknown phase tag: $u", + List(tag) + ) + } + } + ) + + implicit val rlpECDSASignature: RLPCodec[ECDSASignature] = + RLPCodec.instance[ECDSASignature]( + sig => RLPEncoder.encode(sig.toBytes), + { case enc => + val bytes = enc.decodeAs[ByteVector]("signature") + ECDSASignature + .fromBytes(akka.util.ByteString.fromArrayUnsafe(bytes.toArray)) + .getOrElse { + RLPException.decodeError( + "ECDSASignature", + "Invalid signature format.", + List(enc) + ) + } + } + ) + + implicit val rlpGroupSignature + : RLPCodec[CheckpointingAgreement.GroupSignature] = + deriveLabelledGenericRLPCodec + + // Derivation doesn't seem to work on a generic case class.
+ implicit val rlpQuorumCertificate + : RLPCodec[QuorumCertificate[CheckpointingAgreement]] = + RLPCodec.instance[QuorumCertificate[CheckpointingAgreement]]( + { case QuorumCertificate(phase, viewNumber, blockHash, signature) => + RLPList( + RLPEncoder.encode(phase), + RLPEncoder.encode(viewNumber), + RLPEncoder.encode(blockHash), + RLPEncoder.encode(signature) + ) + }, + { case RLPList(phase, viewNumber, blockHash, signature) => + QuorumCertificate[CheckpointingAgreement]( + phase.decodeAs[VotingPhase]("phase"), + viewNumber.decodeAs[ViewNumber]("viewNumber"), + blockHash.decodeAs[CheckpointingAgreement.Hash]("blockHash"), + signature.decodeAs[CheckpointingAgreement.GroupSignature]("signature") + ) + } + ) + + implicit val rlpCheckpointCertificate: RLPCodec[CheckpointCertificate] = + deriveLabelledGenericRLPCodec +} diff --git a/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/RLPHash.scala b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/RLPHash.scala new file mode 100644 index 00000000..785da739 --- /dev/null +++ b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/RLPHash.scala @@ -0,0 +1,45 @@ +package io.iohk.metronome.checkpointing.models + +import io.iohk.ethereum.rlp +import io.iohk.ethereum.rlp.RLPEncoder +import io.iohk.metronome.crypto +import io.iohk.metronome.crypto.hash.Keccak256 +import io.iohk.metronome.core.Tagger +import scodec.bits.ByteVector +import scala.language.implicitConversions + +/** Type class to produce a specific type of hash based on the RLP + * representation of a type, where the hash type is typically + * defined in the companion object of the type. + */ +trait RLPHasher[T] { + type Hash + def hash(value: T): Hash +} +object RLPHasher { + type Aux[T, H] = RLPHasher[T] { + type Hash = H + } +} + +/** Base class for types that have a hash value based on their RLP representation. */ +abstract class RLPHash[T, H](implicit ev: RLPHasher.Aux[T, H]) { self: T => + lazy val hash: H = ev.hash(self) +} + +/** Base class for companion objects for types that need hashes based on RLP. + * + * Every companion will define a separate `Hash` type, so we don't mix them up. + */ +abstract class RLPHashCompanion[T: RLPEncoder] extends RLPHasher[T] { self => + object Hash extends Tagger[ByteVector] + override type Hash = Hash.Tagged + + override def hash(value: T): Hash = + Hash(Keccak256(rlp.encode(value))) + + implicit val hasher: RLPHasher.Aux[T, Hash] = this + + implicit def `Hash => crypto.Hash`(h: Hash): crypto.hash.Hash = + crypto.hash.Hash(h) +} diff --git a/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/Transaction.scala b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/Transaction.scala new file mode 100644 index 00000000..1d6ee6dc --- /dev/null +++ b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/models/Transaction.scala @@ -0,0 +1,55 @@ +package io.iohk.metronome.checkpointing.models + +import scodec.bits.BitVector + +/** Transactions are what comprise the block body used by the Checkpointing Service. + * + * The HotStuff BFT Agreement doesn't need to know about them, their execution and + * validation is delegated to the Checkpointing Service, which, in turn, delegates + * to the interpreter. The only component that truly has to understand the contents + * is the PoW specific interpreter. 
* + * What the Checkpointing Service has to know is the different kinds of transactions + * we support, which are to register proposer blocks in the ledger, required by Advocate, + * and to register checkpoint candidates. + */ +sealed trait Transaction extends RLPHash[Transaction, Transaction.Hash] + +object Transaction + extends RLPHashCompanion[Transaction]()(RLPCodecs.rlpTransaction) { + + /** In PoW chains that support Advocate checkpointing, the Checkpoint Certificate + * can enforce the inclusion of proposed blocks on the chain via references; think + * uncle blocks that also get executed. + * + * In order to know which proposed blocks can be enforced, i.e. ones that are valid + * and have saturated the network, first the federation members need to reach BFT + * agreement over the list of existing proposer blocks. + * + * The `ProposerBlock` transaction adds one of these blocks that exist on the PoW + * chain to the Checkpointing Ledger, iff it can be validated by the members. + * + * The contents of the transaction are opaque; they only need to be understood + * by the PoW side interpreter. + * + * Using Advocate is optional; if the PoW chain doesn't support references, + * it will just use `CheckpointCandidate` transactions. + */ + case class ProposerBlock(value: BitVector) extends Transaction + + /** When a federation member is leading a round, it will ask the PoW side interpreter + * if it wants to propose a checkpoint candidate. The interpreter decides if the + * circumstances are right, e.g. enough new blocks have been built on the previous + * checkpoint that a new one has to be issued. If so, a `CheckpointCandidate` + * transaction is added to the next block, which is sent to the HotStuff replicas + * in a `Prepare` message, to be validated and committed. + * + * If the BFT agreement is successful, a Checkpoint Certificate will be formed + * during block execution which will include the `CheckpointCandidate`. + * + * The contents of the transaction are opaque; they only need to be understood + * by the PoW side interpreter, either for validation, or for following the + * fork indicated by the checkpoint.
+ */ + case class CheckpointCandidate(value: BitVector) extends Transaction +} diff --git a/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/package.scala b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/package.scala new file mode 100644 index 00000000..e9a57fb3 --- /dev/null +++ b/metronome/checkpointing/models/src/io/iohk/metronome/checkpointing/package.scala @@ -0,0 +1,5 @@ +package io.iohk.metronome + +package object checkpointing { + type CheckpointingAgreement = CheckpointingAgreement.type +} diff --git a/metronome/checkpointing/models/test/resources/golden/RLPCodec[CheckpointCertificate].rlp b/metronome/checkpointing/models/test/resources/golden/RLPCodec[CheckpointCertificate].rlp new file mode 100644 index 00000000..25e01a48 --- /dev/null +++ b/metronome/checkpointing/models/test/resources/golden/RLPCodec[CheckpointCertificate].rlp @@ -0,0 +1 @@ +f901d7f8ccf864a07f80dddfff7f3a0b809480dd53010196127f7f7f01270027ff00803651c4fc7f64a0ffff597f000000007fd4486a017f80a57f5cd17f0000ffd41032f77f0080d77fa0ff51807f7f7fc33d7f00d4ff017f7fd5ff008e01007f94017ffb57800174d9fff864a080ff63ff9a798000177f9ab900419d80345eff00007f000144d7e67fffef800020a0147fa3550080ff6c007fd59028ba1b7f4313eed26a7f52c4ad4dd28d804a289ba001a90100ff0096fe4c7f800199747f80ff01ff8085176c00146d8064fff58001d2917f707f808000f7ff5380617d5201005b7ff84502f842a0800380b0c46000ff2c805bfba901c90104011f99241801477e1fa70b8000a101a080ff1315ea810000ee00000db0ed007fff00a79d0000ffb400ffff274f012531f8ad030aa001ffcdffffff0d157f347f018dd1ec012180ea011900a001379f7f01800e3715f888f886b841000000000000000000000000000000000000000000000000000000000000005e00000000000000000000000000000000000000000000000000000000000000111cb841000000000000000000000000000000000000000000000000000000000000002700000000000000000000000000000000000000000000000000000000000000461c \ No newline at end of file diff --git a/metronome/checkpointing/models/test/resources/golden/RLPCodec[CheckpointCertificate].txt b/metronome/checkpointing/models/test/resources/golden/RLPCodec[CheckpointCertificate].txt new file mode 100644 index 00000000..16bb10ac --- /dev/null +++ b/metronome/checkpointing/models/test/resources/golden/RLPCodec[CheckpointCertificate].txt @@ -0,0 +1 @@ +CheckpointCertificate(NonEmptyList(Header(ByteVector(32 bytes, 0x7f80dddfff7f3a0b809480dd53010196127f7f7f01270027ff00803651c4fc7f),100,ByteVector(32 bytes, 0xffff597f000000007fd4486a017f80a57f5cd17f0000ffd41032f77f0080d77f),ByteVector(32 bytes, 0xff51807f7f7fc33d7f00d4ff017f7fd5ff008e01007f94017ffb57800174d9ff)), Header(ByteVector(32 bytes, 0x80ff63ff9a798000177f9ab900419d80345eff00007f000144d7e67fffef8000),32,ByteVector(32 bytes, 0x147fa3550080ff6c007fd59028ba1b7f4313eed26a7f52c4ad4dd28d804a289b),ByteVector(32 bytes, 0x01a90100ff0096fe4c7f800199747f80ff01ff8085176c00146d8064fff58001))),CheckpointCandidate(BitVector(136 bits, 0x7f707f808000f7ff5380617d5201005b7f)),Proof(2,Vector(ByteVector(32 bytes, 0x800380b0c46000ff2c805bfba901c90104011f99241801477e1fa70b8000a101), ByteVector(32 bytes, 0x80ff1315ea810000ee00000db0ed007fff00a79d0000ffb400ffff274f012531))),QuorumCertificate(Commit,10,ByteVector(32 bytes, 0x01ffcdffffff0d157f347f018dd1ec012180ea011900a001379f7f01800e3715),GroupSignature(List(ECDSASignature(94,17,28), ECDSASignature(39,70,28))))) \ No newline at end of file diff --git a/metronome/checkpointing/models/test/resources/golden/RLPCodec[Ledger].rlp b/metronome/checkpointing/models/test/resources/golden/RLPCodec[Ledger].rlp new file mode 100644 index 00000000..0d950b6d --- 
/dev/null +++ b/metronome/checkpointing/models/test/resources/golden/RLPCodec[Ledger].rlp @@ -0,0 +1 @@ +f874e9e8a780a2017fe6000001f6ff6562991fa96676ab000100c6eaff7fb080d1017f4900047f00fbb1ff17f848de9d80d80100567f8f7f4d00ff27843963ffff7aff7f4101ff7f00ffff8001e8a700cb05f2ffff2dd91fff57446e803f3001d7cf80e3b5007f7601ff0708808001e000a0ff6e8057 \ No newline at end of file diff --git a/metronome/checkpointing/models/test/resources/golden/RLPCodec[Ledger].txt b/metronome/checkpointing/models/test/resources/golden/RLPCodec[Ledger].txt new file mode 100644 index 00000000..0fb67def --- /dev/null +++ b/metronome/checkpointing/models/test/resources/golden/RLPCodec[Ledger].txt @@ -0,0 +1 @@ +Ledger(Some(CheckpointCandidate(BitVector(312 bits, 0x80a2017fe6000001f6ff6562991fa96676ab000100c6eaff7fb080d1017f4900047f00fbb1ff17))),Vector(ProposerBlock(BitVector(232 bits, 0x80d80100567f8f7f4d00ff27843963ffff7aff7f4101ff7f00ffff8001)), ProposerBlock(BitVector(312 bits, 0x00cb05f2ffff2dd91fff57446e803f3001d7cf80e3b5007f7601ff0708808001e000a0ff6e8057)))) \ No newline at end of file diff --git a/metronome/checkpointing/models/test/resources/golden/RLPCodec[Transaction].rlp b/metronome/checkpointing/models/test/resources/golden/RLPCodec[Transaction].rlp new file mode 100644 index 00000000..0a293a78 --- /dev/null +++ b/metronome/checkpointing/models/test/resources/golden/RLPCodec[Transaction].rlp @@ -0,0 +1 @@ +ef01adcbff7913ff0000ac1a01009bb245579601b680016500cf02597f070080c318000004ad002faa27b58001ea7f00 \ No newline at end of file diff --git a/metronome/checkpointing/models/test/resources/golden/RLPCodec[Transaction].txt b/metronome/checkpointing/models/test/resources/golden/RLPCodec[Transaction].txt new file mode 100644 index 00000000..66fadbbb --- /dev/null +++ b/metronome/checkpointing/models/test/resources/golden/RLPCodec[Transaction].txt @@ -0,0 +1 @@ +ProposerBlock(BitVector(360 bits, 0xcbff7913ff0000ac1a01009bb245579601b680016500cf02597f070080c318000004ad002faa27b58001ea7f00)) \ No newline at end of file diff --git a/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/CheckpointSigningSpec.scala b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/CheckpointSigningSpec.scala new file mode 100644 index 00000000..0dd90099 --- /dev/null +++ b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/CheckpointSigningSpec.scala @@ -0,0 +1,44 @@ +package io.iohk.metronome.checkpointing + +import io.iohk.metronome.crypto.ECKeyPair +import io.iohk.metronome.hotstuff.consensus.basic.{Signing, VotingPhase} +import io.iohk.metronome.hotstuff.consensus.{ + Federation, + LeaderSelection, + ViewNumber +} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.security.SecureRandom + +/** A single positive case spec to test type interoperability. 
+ * See [[io.iohk.metronome.hotstuff.consensus.basic.Secp256k1SigningProps]] for a more in-depth test + */ +class CheckpointSigningSpec extends AnyFlatSpec with Matchers { + import models.ArbitraryInstances._ + + "Checkpoint signing" should "work :)" in { + val keyPairs = IndexedSeq.fill(2)(ECKeyPair.generate(new SecureRandom)) + val federation = Federation(keyPairs.map(_.pub))(LeaderSelection.RoundRobin) + .getOrElse(throw new Exception("Could not build federation")) + + val signing = implicitly[Signing[CheckpointingAgreement]] + + val phase = sample[VotingPhase] + val viewNumber = sample[ViewNumber] + val hash = sample[CheckpointingAgreement.Hash] + + val partialSigs = + keyPairs.map(kp => signing.sign(kp.prv, phase, viewNumber, hash)) + val groupSig = signing.combine(partialSigs) + + signing.validate( + federation, + groupSig, + phase, + viewNumber, + hash + ) shouldBe true + } +} diff --git a/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/ArbitraryInstances.scala b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/ArbitraryInstances.scala new file mode 100644 index 00000000..f4e092ea --- /dev/null +++ b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/ArbitraryInstances.scala @@ -0,0 +1,149 @@ +package io.iohk.metronome.checkpointing.models + +import cats.data.NonEmptyList +import io.iohk.ethereum.crypto.ECDSASignature +import io.iohk.metronome.checkpointing.CheckpointingAgreement +import io.iohk.metronome.crypto.hash.Hash +import io.iohk.metronome.hotstuff.consensus.basic.{ + Phase, + QuorumCertificate, + VotingPhase +} +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import org.scalacheck._ +import org.scalacheck.Arbitrary.arbitrary +import scodec.bits.BitVector +import io.iohk.metronome.crypto.GroupSignature + +object ArbitraryInstances + extends io.iohk.metronome.hotstuff.consensus.ArbitraryInstances { + implicit val arbBitVector: Arbitrary[BitVector] = + Arbitrary { + for { + // Choose a size that BitVector still renders as hex in toString, + // so the exact value is easy to see in test output or golden files. + // Over that it renders hashCode which can differ between Scala versions. 
+ n <- Gen.choose(0, 64) + bs <- Gen.listOfN(n, arbitrary[Byte]) + } yield BitVector(bs.toArray) + } + + implicit val arbHeaderHash: Arbitrary[Block.Header.Hash] = + Arbitrary(arbitrary[Hash].map(Block.Header.Hash(_))) + + implicit val arbBodyHash: Arbitrary[Block.Body.Hash] = + Arbitrary(arbitrary[Hash].map(Block.Body.Hash(_))) + + implicit val arbLedgerHash: Arbitrary[Ledger.Hash] = + Arbitrary(arbitrary[Hash].map(Ledger.Hash(_))) + + implicit val arbMerkleHash: Arbitrary[MerkleTree.Hash] = + Arbitrary(arbitrary[Hash].map(MerkleTree.Hash(_))) + + implicit val arbProposerBlock: Arbitrary[Transaction.ProposerBlock] = + Arbitrary { + arbitrary[BitVector].map(Transaction.ProposerBlock(_)) + } + + implicit val arbCheckpointCandidate + : Arbitrary[Transaction.CheckpointCandidate] = + Arbitrary { + arbitrary[BitVector].map(Transaction.CheckpointCandidate(_)) + } + + implicit val arbTransaction: Arbitrary[Transaction] = + Arbitrary { + Gen.frequency( + 4 -> arbitrary[Transaction.ProposerBlock], + 1 -> arbitrary[Transaction.CheckpointCandidate] + ) + } + + implicit val arbLedger: Arbitrary[Ledger] = + Arbitrary { + for { + mcp <- arbitrary[Option[Transaction.CheckpointCandidate]] + pbs <- arbitrary[Set[Transaction.ProposerBlock]].map(_.toVector) + } yield Ledger(mcp, pbs) + } + + implicit val arbBlock: Arbitrary[Block] = + Arbitrary { + for { + parentHash <- arbitrary[Block.Header.Hash] + height <- Gen.posNum[Long] + postStateHash <- arbitrary[Ledger.Hash] + transactions <- arbitrary[Vector[Transaction]] + contentMerkleRoot <- arbitrary[MerkleTree.Hash] + body = Block.Body(transactions) + header = Block.Header( + parentHash, + height, + postStateHash, + contentMerkleRoot + ) + } yield Block.makeUnsafe(header, body) + } + + implicit val arbBlockHeader: Arbitrary[Block.Header] = + Arbitrary(arbitrary[Block].map(_.header)) + + implicit val arbECDSASignature: Arbitrary[ECDSASignature] = + Arbitrary { + for { + r <- Gen.posNum[BigInt] + s <- Gen.posNum[BigInt] + v <- Gen.oneOf( + ECDSASignature.positivePointSign, + ECDSASignature.negativePointSign + ) + } yield ECDSASignature(r, s, v) + } + + implicit val arbQuorumCertificate + : Arbitrary[QuorumCertificate[CheckpointingAgreement]] = + Arbitrary { + for { + phase <- arbitrary[VotingPhase] + viewNumber <- arbitrary[ViewNumber] + blockHash <- arbitrary[Block.Header.Hash] + signature <- arbitrary[CheckpointingAgreement.GSig] + } yield QuorumCertificate[CheckpointingAgreement]( + phase, + viewNumber, + blockHash, + GroupSignature(signature) + ) + } + + implicit val arbCheckpointCertificate: Arbitrary[CheckpointCertificate] = + Arbitrary { + for { + n <- Gen.posNum[Int] + headers <- Gen + .listOfN(n, arbitrary[Block.Header]) + .map(NonEmptyList.fromListUnsafe(_)) + + checkpoint <- arbitrary[Transaction.CheckpointCandidate] + + leafIndex <- Gen.choose(0, 10) + siblings <- arbitrary[Vector[MerkleTree.Hash]] + proof = MerkleTree.Proof(leafIndex, siblings) + + viewNumber <- Gen.posNum[Long].map(x => ViewNumber(x + n)) + signature <- arbitrary[CheckpointingAgreement.GSig] + commitQC = QuorumCertificate[CheckpointingAgreement]( + phase = Phase.Commit, + viewNumber = viewNumber, + blockHash = headers.head.hash, + signature = GroupSignature(signature) + ) + + } yield CheckpointCertificate( + headers, + checkpoint, + proof, + commitQC + ) + } +} diff --git a/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/LedgerProps.scala b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/LedgerProps.scala new file mode 
100644 index 00000000..f32ef0a5 --- /dev/null +++ b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/LedgerProps.scala @@ -0,0 +1,32 @@ +package io.iohk.metronome.checkpointing.models + +import org.scalacheck._ +import org.scalacheck.Prop.forAll + +object LedgerProps extends Properties("Ledger") { + import ArbitraryInstances._ + + property("update") = forAll { (ledger: Ledger, transaction: Transaction) => + val updated = ledger.update(transaction) + + transaction match { + case _: Transaction.ProposerBlock + if ledger.proposerBlocks.contains(transaction) => + updated == ledger + + case _: Transaction.ProposerBlock => + updated.proposerBlocks.last == transaction && + updated.proposerBlocks.distinct == updated.proposerBlocks && + updated.maybeLastCheckpoint == ledger.maybeLastCheckpoint + + case _: Transaction.CheckpointCandidate => + updated.maybeLastCheckpoint.contains(transaction) && + updated.proposerBlocks.isEmpty + } + } + + property("hash") = forAll { (ledger1: Ledger, ledger2: Ledger) => + ledger1 == ledger2 && ledger1.hash == ledger2.hash || + ledger1 != ledger2 && ledger1.hash != ledger2.hash + } +} diff --git a/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/MerkleTreeProps.scala b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/MerkleTreeProps.scala new file mode 100644 index 00000000..55c83e64 --- /dev/null +++ b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/MerkleTreeProps.scala @@ -0,0 +1,65 @@ +package io.iohk.metronome.checkpointing.models + +import org.scalacheck.{Gen, Properties} +import org.scalacheck.Arbitrary.arbitrary +import ArbitraryInstances.arbMerkleHash +import org.scalacheck.Prop.forAll + +object MerkleTreeProps extends Properties("MerkleTree") { + + def genElements(max: Int = 256): Gen[List[MerkleTree.Hash]] = + Gen.choose(0, max).flatMap { n => + Gen.listOfN(n, arbitrary(arbMerkleHash)) + } + + property("inclusionProof") = forAll(genElements()) { elements => + val merkleTree = MerkleTree.build(elements) + elements.zipWithIndex.forall { case (elem, idx) => + val fromHash = MerkleTree.generateProofFromHash(merkleTree, elem) + val fromIndex = MerkleTree.generateProofFromIndex(merkleTree, idx) + fromHash == fromIndex && fromHash.isDefined + } + } + + property("proofVerification") = forAll(genElements()) { elements => + val merkleTree = MerkleTree.build(elements) + elements.forall { elem => + val maybeProof = MerkleTree.generateProofFromHash(merkleTree, elem) + maybeProof.exists(MerkleTree.verifyProof(_, merkleTree.hash, elem)) + } + } + + property("noFalseInclusion") = forAll(genElements(128), genElements(32)) { + (elements, other) => + val nonElements = other.diff(elements) + val merkleTree = MerkleTree.build(elements) + + val noFalseProof = nonElements.forall { nonElem => + MerkleTree.generateProofFromHash(merkleTree, nonElem).isEmpty + } + + val noFalseVerification = elements.forall { elem => + val proof = MerkleTree.generateProofFromHash(merkleTree, elem).get + !nonElements.exists(MerkleTree.verifyProof(proof, merkleTree.hash, _)) + } + + noFalseProof && noFalseVerification + } + + property("emptyTree") = { + val empty = MerkleTree.build(Nil) + + MerkleTree.generateProofFromHash(empty, MerkleTree.empty.hash).isEmpty && + empty.hash == MerkleTree.empty.hash + } + + property("singleElementTree") = forAll(arbMerkleHash.arbitrary) { elem => + val tree = MerkleTree.build(elem :: Nil) + + tree.hash == elem && + MerkleTree + 
.generateProofFromHash(tree, elem) + .map(MerkleTree.verifyProof(_, tree.hash, elem)) + .contains(true) + } +} diff --git a/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/RLPCodecsProps.scala b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/RLPCodecsProps.scala new file mode 100644 index 00000000..5b8d6d82 --- /dev/null +++ b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/RLPCodecsProps.scala @@ -0,0 +1,28 @@ +package io.iohk.metronome.checkpointing.models + +import io.iohk.ethereum.crypto.ECDSASignature +import io.iohk.ethereum.rlp +import io.iohk.ethereum.rlp.RLPCodec +import org.scalacheck._ +import org.scalacheck.Prop.forAll +import scala.reflect.ClassTag + +object RLPCodecsProps extends Properties("RLPCodecs") { + import ArbitraryInstances._ + import RLPCodecs._ + + /** Test that encoding to and decoding from RLP preserves the value. */ + def propRoundTrip[T: RLPCodec: Arbitrary: ClassTag] = + property(implicitly[ClassTag[T]].runtimeClass.getSimpleName) = forAll { + (value0: T) => + val bytes = rlp.encode(value0) + val value1 = rlp.decode[T](bytes) + value0 == value1 + } + + propRoundTrip[Ledger] + propRoundTrip[Transaction] + propRoundTrip[Block] + propRoundTrip[ECDSASignature] + propRoundTrip[CheckpointCertificate] +} diff --git a/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/RLPCodecsSpec.scala b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/RLPCodecsSpec.scala new file mode 100644 index 00000000..ca444359 --- /dev/null +++ b/metronome/checkpointing/models/test/src/io/iohk/metronome/checkpointing/models/RLPCodecsSpec.scala @@ -0,0 +1,236 @@ +package io.iohk.metronome.checkpointing.models + +import cats.data.NonEmptyList +import io.iohk.ethereum.crypto.ECDSASignature +import io.iohk.ethereum.rlp._ +import io.iohk.ethereum.rlp +import io.iohk.metronome.crypto.GroupSignature +import io.iohk.metronome.checkpointing.CheckpointingAgreement +import io.iohk.metronome.hotstuff.consensus.basic.{Phase, QuorumCertificate} +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import java.nio.file.{Files, Path, StandardOpenOption} +import org.scalactic.Equality +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import scala.reflect.ClassTag +import scodec.bits.BitVector + +/** Concrete examples of RLP encoding, so we can make sure the structure is what we expect. + * + * Complements `RLPCodecsProps` which works with arbitrary data. + */ +class RLPCodecsSpec extends AnyFlatSpec with Matchers { + import ArbitraryInstances._ + import RLPCodecs._ + + // Structrual equality checker for RLPEncodeable. + // It has different wrappers for items based on whether it was hand crafted or generated + // by codecs, and the RLPValue has mutable arrays inside. 
+ implicit val eqRLPList = new Equality[RLPEncodeable] { + override def areEqual(a: RLPEncodeable, b: Any): Boolean = + (a, b) match { + case (a: RLPList, b: RLPList) => + a.items.size == b.items.size && a.items.zip(b.items).forall { + case (a, b) => + areEqual(a, b) + } + case (a: RLPValue, b: RLPValue) => + a.bytes.sameElements(b.bytes) + case _ => + false + } + } + + abstract class Example[T: RLPCodec: ClassTag] { + def decoded: T + def encoded: RLPEncodeable + + def name = + s"RLPCodec[${implicitly[ClassTag[T]].runtimeClass.getSimpleName}]" + + def encode: RLPEncodeable = RLPEncoder.encode(decoded) + def decode: T = RLPDecoder.decode[T](encoded) + + def decode(bytes: BitVector): T = + rlp.decode[T](bytes.toByteArray) + } + + def exampleBehavior[T](example: Example[T]) = { + it should "encode the example value to the expected RLP data" in { + example.encode shouldEqual example.encoded + } + + it should "decode the example RLP data to the expected value" in { + example.decode shouldEqual example.decoded + } + } + + /** When the example is first executed, create a golden file that we can + * check in with the code for future reference, and to detect any regression. + * + * If there are intentional changes, just delete it and let it be recreated. + * This could be used as a starter for implemnting the same format in a + * different language. + * + * The String format is not expected to match other implementations, but it's + * easy enough to read, and should be as good as a hard coded example either + * in code or a README file. + */ + def goldenBehavior[T](example: Example[T]) = { + + def resourcePath(extension: String): Path = { + val goldenPath = Path.of(getClass.getResource("/golden").toURI) + goldenPath.resolve(s"${example.name}.${extension}") + } + + def maybeCreateResource(path: Path, content: => String) = { + if (!Files.exists(path)) { + Files.writeString(path, content, StandardOpenOption.CREATE_NEW) + } + } + + val goldenRlpPath = resourcePath("rlp") + val goldenTxtPath = resourcePath("txt") + + maybeCreateResource( + goldenRlpPath, + BitVector(rlp.encode(example.encoded)).toHex + ) + + maybeCreateResource( + goldenTxtPath, + example.decoded.toString + ) + + it should "decode the golden RLP content to a value that matches the golden String" in { + val goldenRlp = BitVector.fromHex(Files.readString(goldenRlpPath)).get + val goldenTxt = Files.readString(goldenTxtPath) + + example.decode(goldenRlp).toString shouldBe goldenTxt + } + } + + def test[T](example: Example[T]) = { + example.name should behave like exampleBehavior(example) + example.name should behave like goldenBehavior(example) + } + + test { + new Example[Ledger] { + override val decoded = Ledger( + maybeLastCheckpoint = Some( + sample[Transaction.CheckpointCandidate] + ), + proposerBlocks = Vector( + sample[Transaction.ProposerBlock], + sample[Transaction.ProposerBlock] + ) + ) + + override val encoded = + RLPList( // Ledger + RLPList( // Option + RLPList( // CheckpointCandidate + RLPValue(decoded.maybeLastCheckpoint.get.value.toByteArray) + ) + ), + RLPList( // Vector + RLPList( // ProposerBlock + RLPValue(decoded.proposerBlocks(0).value.toByteArray) + ), + RLPList(RLPValue(decoded.proposerBlocks(1).value.toByteArray)) + ) + ) + } + } + + test { + new Example[Transaction] { + override val decoded = sample[Transaction.ProposerBlock] + + override val encoded = + RLPList( // ProposerBlock + RLPValue(Array(1.toByte)), // Tag + RLPValue(decoded.value.toByteArray) + ) + } + } + + test { + new Example[CheckpointCertificate] { 
+ val decoded = CheckpointCertificate( + headers = NonEmptyList.of( + sample[Block.Header], + sample[Block.Header] + ), + checkpoint = sample[Transaction.CheckpointCandidate], + proof = MerkleTree.Proof( + leafIndex = 2, + siblingPath = Vector(sample[MerkleTree.Hash], sample[MerkleTree.Hash]) + ), + commitQC = QuorumCertificate[CheckpointingAgreement]( + phase = Phase.Commit, + viewNumber = ViewNumber(10), + blockHash = sample[Block.Header.Hash], + signature = GroupSignature( + List( + sample[ECDSASignature], + sample[ECDSASignature] + ) + ) + ) + ) + + override val encoded = + RLPList( // CheckpointCertificate + RLPList( // NonEmptyList + RLPList( // BlockHeader + RLPValue(decoded.headers.head.parentHash.toArray), + RLPValue( + rlp.RLPImplicits.longEncDec + .encode(decoded.headers.head.height) + .bytes + ), + RLPValue(decoded.headers.head.postStateHash.toArray), + RLPValue(decoded.headers.head.contentMerkleRoot.toArray) + ), + RLPList( // BlockHeader + RLPValue(decoded.headers.last.parentHash.toArray), + RLPValue( + rlp.RLPImplicits.longEncDec + .encode(decoded.headers.last.height) + .bytes + ), + RLPValue(decoded.headers.last.postStateHash.toArray), + RLPValue(decoded.headers.last.contentMerkleRoot.toArray) + ) + ), + RLPList( // CheckpointCandidate + RLPValue(decoded.checkpoint.value.toByteArray) + ), + RLPList( // Proof + RLPValue(Array(decoded.proof.leafIndex.toByte)), + RLPList( // siblingPath + RLPValue(decoded.proof.siblingPath.head.toArray), + RLPValue(decoded.proof.siblingPath.last.toArray) + ) + ), + RLPList( // QuorumCertificate + RLPValue(Array(3.toByte)), // Commit + RLPValue(Array(10.toByte)), // ViewNumber + RLPValue(decoded.commitQC.blockHash.toArray), + RLPList( // GroupSignature + RLPList( // sig + RLPValue( // ECDSASignature + decoded.commitQC.signature.sig.head.toBytes.toArray[Byte] + ), + RLPValue( + decoded.commitQC.signature.sig.last.toBytes.toArray[Byte] + ) + ) + ) + ) + ) + } + } +} diff --git a/metronome/checkpointing/service/src/io/iohk/metronome/checkpointing/service/CheckpointingService.scala b/metronome/checkpointing/service/src/io/iohk/metronome/checkpointing/service/CheckpointingService.scala new file mode 100644 index 00000000..42bc2a92 --- /dev/null +++ b/metronome/checkpointing/service/src/io/iohk/metronome/checkpointing/service/CheckpointingService.scala @@ -0,0 +1,323 @@ +package io.iohk.metronome.checkpointing.service + +import cats.data.{NonEmptyList, NonEmptyVector, OptionT} +import cats.effect.concurrent.Ref +import cats.effect.{Concurrent, Resource, Sync} +import cats.implicits._ +import io.iohk.metronome.checkpointing.CheckpointingAgreement +import io.iohk.metronome.checkpointing.models.Transaction.CheckpointCandidate +import io.iohk.metronome.checkpointing.models.{ + Block, + CheckpointCertificate, + Ledger +} +import io.iohk.metronome.checkpointing.service.CheckpointingService.{ + CheckpointData, + LedgerNode, + LedgerTree +} +import io.iohk.metronome.checkpointing.service.storage.LedgerStorage +import io.iohk.metronome.crypto.ECPublicKey +import io.iohk.metronome.hotstuff.consensus.basic.{Phase, QuorumCertificate} +import io.iohk.metronome.hotstuff.service.ApplicationService +import io.iohk.metronome.hotstuff.service.storage.{ + BlockStorage, + ViewStateStorage +} +import io.iohk.metronome.storage.KVStoreRunner + +import scala.annotation.tailrec + +class CheckpointingService[F[_]: Sync, N]( + ledgerTreeRef: Ref[F, LedgerTree], + lastCommittedHeaderRef: Ref[F, Block.Header], + checkpointDataRef: Ref[F, Option[CheckpointData]], + //TODO: 
PM-3137, this is used for testing that a certificate was created correctly + // replace with proper means of pushing the certificate to the Interpreter + pushCheckpointFn: CheckpointCertificate => F[Unit], + ledgerStorage: LedgerStorage[N], + blockStorage: BlockStorage[N, CheckpointingAgreement] +)(implicit storeRunner: KVStoreRunner[F, N]) + extends ApplicationService[F, CheckpointingAgreement] { + + override def createBlock( + highQC: QuorumCertificate[CheckpointingAgreement] + ): F[Option[Block]] = ??? + + override def validateBlock(block: Block): F[Option[Boolean]] = { + val ledgers = for { + nextLedger <- OptionT(projectLedger(block)) + tree <- OptionT.liftF(ledgerTreeRef.get) + prevLedger <- tree.get(block.header.parentHash).map(_.ledger).toOptionT[F] + } yield (prevLedger, nextLedger) + + ledgers.value.flatMap { + case Some((prevLedger, nextLedger)) + if nextLedger.hash == block.header.postStateHash => + validateTransactions(block.body, prevLedger) + + case _ => false.some.pure[F] + } + } + + private def validateTransactions( + body: Block.Body, + ledger: Ledger + ): F[Option[Boolean]] = { + //TODO: Validate transactions PM-3131/3132 + true.some.pure[F] + } + + override def executeBlock( + block: Block, + commitQC: QuorumCertificate[CheckpointingAgreement], + commitPath: NonEmptyList[Block.Hash] + ): F[Boolean] = { + require(commitQC.phase == Phase.Commit, "Commit QC required") + projectLedger(block).flatMap { + case Some(ledger) => + updateCheckpointData(block).flatMap { checkpointDataOpt => + if (block.hash != commitQC.blockHash) + false.pure[F] + else { + val certificateOpt = checkpointDataOpt + .flatMap { cd => + CheckpointCertificate + .construct(cd.block, cd.headers.toNonEmptyList, commitQC) + } + .toOptionT[F] + + saveLedger(block.header, ledger) >> + certificateOpt.cataF(().pure[F], pushCheckpoint) >> + true.pure[F] + } + } + + case None => + Sync[F].raiseError( + new IllegalStateException(s"Could not execute block: ${block.hash}") + ) + } + } + + /** Computes and saves the intermediate ledgers leading up to and including + * the one resulting from the `block` transactions, either by looking up + * already computed ledgers in the `ledgerTree` or fetching ancestor blocks + * from `blockStorage`. + * Only descendants of the root of the `ledgerTree` (last committed ledger) + * will be evaluated + */ + private def projectLedger(block: Block): F[Option[Ledger]] = { + (for { + ledgerTree <- ledgerTreeRef.get + commitHeight <- lastCommittedHeaderRef.get.map(_.height) + } yield { + def loop(block: Block): OptionT[F, Ledger] = { + def doUpdate(ledger: Ledger) = + OptionT.liftF(updateLedgerByBlock(ledger, block)) + + ledgerTree.get(block.header.parentHash) match { + case Some(oldLedger) => + doUpdate(oldLedger.ledger) + + case None if block.header.height <= commitHeight => + OptionT.none + + case None => + for { + parent <- OptionT(getBlock(block.header.parentHash)) + oldLedger <- loop(parent) + newLedger <- doUpdate(oldLedger) + } yield newLedger + } + } + + ledgerTree + .get(block.hash) + .map(_.ledger) + .toOptionT[F] + .orElse(loop(block)) + .value + }).flatten + } + + /** Computes a new ledger from the `block` and saves it in the ledger tree only if + * a parent state exists. 
+ * + * Because we're only adding to the tree no locking around it is necessary + */ + private def updateLedgerByBlock( + oldLedger: Ledger, + block: Block + ): F[Ledger] = { + val newLedger = oldLedger.update(block.body.transactions) + + ledgerTreeRef + .update { tree => + if (tree.contains(block.header.parentHash)) + tree + (block.hash -> LedgerNode(newLedger, block.header)) + else + tree + } + .as(newLedger) + } + + private def updateCheckpointData( + block: Block + ): F[Option[CheckpointData]] = { + val containsCheckpoint = block.body.transactions.exists { + case _: CheckpointCandidate => true + case _ => false + } + + checkpointDataRef.updateAndGet { cd => + if (containsCheckpoint) + CheckpointData(block).some + else + cd.map(_.extend(block.header)) + } + } + + private def getBlock(hash: Block.Hash): F[Option[Block]] = + storeRunner.runReadOnly(blockStorage.get(hash)) + + private def saveLedger(header: Block.Header, ledger: Ledger): F[Unit] = { + storeRunner.runReadWrite { + ledgerStorage.put(ledger) + } >> + ledgerTreeRef.update(clearLedgerTree(header, ledger)) >> + lastCommittedHeaderRef.set(header) >> + checkpointDataRef.set(None) + } + + /** Makes the `commitHeader` and the associated 'ledger' the root of the tree, + * while retaining any descendants of the `commitHeader` + */ + private def clearLedgerTree(commitHeader: Block.Header, ledger: Ledger)( + ledgerTree: LedgerTree + ): LedgerTree = { + + @tailrec + def loop( + oldTree: LedgerTree, + newTree: LedgerTree, + height: Long + ): LedgerTree = + if (oldTree.isEmpty) newTree + else { + val (higherLevels, currentLevel) = oldTree.partition { case (_, ln) => + ln.height > height + } + val children = currentLevel.filter { case (_, ln) => + newTree.contains(ln.parentHash) + } + loop(higherLevels, newTree ++ children, height + 1) + } + + loop( + ledgerTree.filter { case (_, ln) => ln.height > commitHeader.height }, + LedgerTree.root(ledger, commitHeader), + commitHeader.height + 1 + ) + } + + private def pushCheckpoint(checkpoint: CheckpointCertificate): F[Unit] = + pushCheckpointFn(checkpoint) //TODO: PM-3137 + + override def syncState( + sources: NonEmptyVector[ECPublicKey], + block: Block + ): F[Boolean] = ??? 
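For reference, the way `executeBlock` above turns the tracked checkpoint data into a certificate can be sketched on its own (the helper name is made up; `CheckpointData` is defined in the companion object just below):

```scala
import io.iohk.metronome.checkpointing.CheckpointingAgreement
import io.iohk.metronome.checkpointing.models.CheckpointCertificate
import io.iohk.metronome.checkpointing.service.CheckpointingService.CheckpointData
import io.iohk.metronome.hotstuff.consensus.basic.QuorumCertificate

// Hypothetical helper mirroring the flow in `executeBlock`: the block that carried the
// last CheckpointCandidate plus the headers leading up to the committed block, together
// with the Commit Q.C., are enough to build (or fail to build) a CheckpointCertificate.
def toCertificate(
    cd: CheckpointData,
    commitQC: QuorumCertificate[CheckpointingAgreement]
): Option[CheckpointCertificate] =
  CheckpointCertificate.construct(cd.block, cd.headers.toNonEmptyList, commitQC)
```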
+} + +object CheckpointingService { + + /** A node in LedgerTree + * `parentHash` and `height` are helpful when resetting the tree + */ + case class LedgerNode( + ledger: Ledger, + parentHash: Block.Hash, + height: Long + ) + + object LedgerNode { + def apply(ledger: Ledger, header: Block.Header): LedgerNode = + LedgerNode(ledger, header.parentHash, header.height) + } + + /** The internal structure used to represent intermediate ledgers resulting + * from execution and validation + */ + type LedgerTree = Map[Block.Hash, LedgerNode] + + object LedgerTree { + def root(ledger: Ledger, header: Block.Header): LedgerTree = + Map(header.hash -> LedgerNode(ledger, header)) + } + + /** Used to track the most recent checkpoint candidate + * `block` - last containing a checkpoint candidate + * `headers` - path from the `block` to the last executed one + * + * These values along with Commit QC can be used to construct + * a `CheckpointCertificate` + */ + case class CheckpointData( + block: Block, + headers: NonEmptyVector[Block.Header] + ) { + def extend(header: Block.Header): CheckpointData = + copy(headers = headers :+ header) + } + + object CheckpointData { + def apply(block: Block): CheckpointData = + CheckpointData(block, NonEmptyVector.of(block.header)) + } + + def apply[F[_]: Concurrent, N]( + ledgerStorage: LedgerStorage[N], + blockStorage: BlockStorage[N, CheckpointingAgreement], + viewStateStorage: ViewStateStorage[N, CheckpointingAgreement], + pushCheckpointFn: CheckpointCertificate => F[Unit] + )(implicit + storeRunner: KVStoreRunner[F, N] + ): Resource[F, CheckpointingService[F, N]] = { + val lastExecuted: F[(Block, Ledger)] = + storeRunner.runReadOnly { + val query = for { + blockHash <- OptionT.liftF( + viewStateStorage.getLastExecutedBlockHash + ) + block <- OptionT(blockStorage.get(blockHash)) + //a genesis (empty) state should be present in LedgerStorage on first run + ledger <- OptionT(ledgerStorage.get(block.header.postStateHash)) + } yield (block, ledger) + query.value + } >>= { + _.toOptionT[F].getOrElseF { + Sync[F].raiseError( + new IllegalStateException("Last executed block/state not found") + ) + } + } + + val service = for { + (block, ledger) <- lastExecuted + ledgerTree <- Ref.of(LedgerTree.root(ledger, block.header)) + lastExec <- Ref.of(block.header) + checkpointData <- Ref.of(None: Option[CheckpointData]) + } yield new CheckpointingService[F, N]( + ledgerTree, + lastExec, + checkpointData, + pushCheckpointFn, + ledgerStorage, + blockStorage + ) + + Resource.liftF(service) + } + +} diff --git a/metronome/checkpointing/service/src/io/iohk/metronome/checkpointing/service/messages/CheckpointingMessage.scala b/metronome/checkpointing/service/src/io/iohk/metronome/checkpointing/service/messages/CheckpointingMessage.scala new file mode 100644 index 00000000..de7bb493 --- /dev/null +++ b/metronome/checkpointing/service/src/io/iohk/metronome/checkpointing/service/messages/CheckpointingMessage.scala @@ -0,0 +1,31 @@ +package io.iohk.metronome.checkpointing.service.messages + +import io.iohk.metronome.checkpointing.models.Ledger +import io.iohk.metronome.core.messages.{RPCMessage, RPCMessageCompanion} + +/** Checkpointing specific messages that the HotStuff service doesn't handle, + * which is the synchronisation of committed ledger state. + * + * These will be wrapped in an `ApplicationMessage`. + */ +sealed trait CheckpointingMessage { self: RPCMessage => } + +object CheckpointingMessage extends RPCMessageCompanion { + + /** Request the ledger state given by a specific hash. 
+ * + * The hash is something coming from a block that was + * pointed at by a Commit Q.C. + */ + case class GetStateRequest( + requestId: RequestId, + stateHash: Ledger.Hash + ) extends CheckpointingMessage + with Request + + case class GetStateResponse( + requestId: RequestId, + state: Ledger + ) extends CheckpointingMessage + with Response +} diff --git a/metronome/checkpointing/service/src/io/iohk/metronome/checkpointing/service/storage/LedgerStorage.scala b/metronome/checkpointing/service/src/io/iohk/metronome/checkpointing/service/storage/LedgerStorage.scala new file mode 100644 index 00000000..ff11d4ef --- /dev/null +++ b/metronome/checkpointing/service/src/io/iohk/metronome/checkpointing/service/storage/LedgerStorage.scala @@ -0,0 +1,43 @@ +package io.iohk.metronome.checkpointing.service.storage + +import cats.implicits._ +import io.iohk.metronome.checkpointing.models.Ledger +import io.iohk.metronome.storage.{KVRingBuffer, KVCollection, KVStore} +import scodec.Codec + +/** Storing the committed and executed checkpoint ledger. + * + * Strictly speaking the application only needs the committed state, + * since it has been signed by the federation and we know it's not + * going to be rolled back. Uncommitted state can be kept in memory. + * + * However we want to support other nodes catching up by: + * 1. requesting the latest Commit Q.C., then + * 2. requesting the block the Commit Q.C. points at, then + * 3. requesting the ledger state the header points at. + * + * We have to allow some time before we get rid of historical state, + * so that it doesn't disappear between step 2 and 3, resulting in + * nodes trying and trying to catch up but always missing the beat. + * + * Therefore we keep a collection of the last N ledgers in a ring buffer. + */ +class LedgerStorage[N]( + ledgerColl: KVCollection[N, Ledger.Hash, Ledger], + ledgerMetaNamespace: N, + maxHistorySize: Int +)(implicit codecH: Codec[Ledger.Hash]) + extends KVRingBuffer[N, Ledger.Hash, Ledger]( + ledgerColl, + ledgerMetaNamespace, + maxHistorySize + ) { + + /** Save a new ledger and remove the oldest one, if we reached + * the maximum history size. Since we only store committed + * state, they form a chain. They will always be retrieved + * by going through a block pointing at them directly. 
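A rough usage sketch of this ring buffer, following the same pattern as `LedgerStorageProps` further down in this diff (the namespace strings, history size and dummy codec are illustrative only):

```scala
import cats.implicits._
import io.iohk.metronome.checkpointing.models.Ledger
import io.iohk.metronome.storage.{KVCollection, KVStoreState}
import scodec.Codec
import scodec.bits.BitVector

object LedgerStorageExample {
  object ExampleKVStore extends KVStoreState[String]

  // The in-memory store never invokes the codecs, so a dummy one is enough here.
  implicit def unusedCodec[T]: Codec[T] =
    Codec[T](
      (_: T) => sys.error("Didn't expect to encode."),
      (_: BitVector) => sys.error("Didn't expect to decode.")
    )

  val ledgerStorage = new LedgerStorage[String](
    new KVCollection[String, Ledger.Hash, Ledger]("example-ledgers"),
    "example-ledger-meta",
    maxHistorySize = 2
  )

  // Write a batch of committed ledgers; with maxHistorySize = 2 only the two most
  // recently written ones remain retrievable by hash afterwards.
  def writeAll(ledgers: List[Ledger]) =
    ExampleKVStore
      .compile(ledgers.traverse(ledgerStorage.put))
      .runS(Map.empty)
      .value
}
```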
+ */ + def put(ledger: Ledger): KVStore[N, Unit] = + put(ledger.hash, ledger).void +} diff --git a/metronome/checkpointing/service/test/src/io/iohk/metronome/checkpointing/service/CheckpointingServiceProps.scala b/metronome/checkpointing/service/test/src/io/iohk/metronome/checkpointing/service/CheckpointingServiceProps.scala new file mode 100644 index 00000000..f079c6c5 --- /dev/null +++ b/metronome/checkpointing/service/test/src/io/iohk/metronome/checkpointing/service/CheckpointingServiceProps.scala @@ -0,0 +1,487 @@ +package io.iohk.metronome.checkpointing.service + +import cats.data.NonEmptyList +import cats.effect.Resource +import cats.effect.concurrent.Ref +import cats.implicits._ +import io.iohk.metronome.checkpointing.CheckpointingAgreement +import io.iohk.metronome.checkpointing.models.Block.{Hash, Header} +import io.iohk.metronome.checkpointing.models.Transaction.CheckpointCandidate +import io.iohk.metronome.checkpointing.models.{ + ArbitraryInstances, + Block, + CheckpointCertificate, + Ledger +} +import io.iohk.metronome.checkpointing.service.CheckpointingService.{ + CheckpointData, + LedgerNode, + LedgerTree +} +import io.iohk.metronome.checkpointing.service.storage.LedgerStorage +import io.iohk.metronome.checkpointing.service.storage.LedgerStorageProps.{ + neverUsedCodec, + Namespace => LedgerNamespace +} +import io.iohk.metronome.hotstuff.consensus.basic.Phase.Commit +import io.iohk.metronome.hotstuff.consensus.basic.QuorumCertificate +import io.iohk.metronome.hotstuff.service.storage.BlockStorage +import io.iohk.metronome.hotstuff.service.storage.BlockStorageProps.{ + Namespace => BlockNamespace +} +import io.iohk.metronome.storage.{ + InMemoryKVStore, + KVCollection, + KVStoreRunner, + KVStoreState, + KVTree +} +import monix.eval.Task +import monix.execution.Scheduler +import org.scalacheck.Arbitrary.arbitrary +import org.scalacheck.Prop.{all, classify, forAll, forAllNoShrink, propBoolean} +import org.scalacheck.{Gen, Prop, Properties} + +import scala.concurrent.duration._ +import scala.util.Random + +/** Props for Checkpointing service + * + * Do take note of tests that use `classify` to report whether parallelism + * was achieved. 
This is not a hard requirement because it may fail on CI, + * but one should make sure to achieve 100% parallelism locally when making + * changes to this tests or the service + */ +class CheckpointingServiceProps extends Properties("CheckpointingService") { + + type Namespace = String + + case class TestResources( + checkpointingService: CheckpointingService[Task, Namespace], + ledgerStorage: LedgerStorage[Namespace], + blockStorage: BlockStorage[Namespace, CheckpointingAgreement], + store: KVStoreRunner[Task, Namespace], + ledgerTreeRef: Ref[Task, LedgerTree], + checkpointDataRef: Ref[Task, Option[CheckpointData]], + lastCheckpointCertRef: Ref[Task, Option[CheckpointCertificate]] + ) + + case class TestFixture( + initialBlock: Block, + initialLedger: Ledger, + batch: List[Block], + commitQC: QuorumCertificate[CheckpointingAgreement], + randomSeed: Long + ) { + val resources: Resource[Task, TestResources] = { + val ledgerStorage = + new LedgerStorage[Namespace]( + new KVCollection[Namespace, Ledger.Hash, Ledger]( + LedgerNamespace.Ledgers + ), + LedgerNamespace.LedgerMeta, + maxHistorySize = 10 + ) + + val blockStorage = new BlockStorage[Namespace, CheckpointingAgreement]( + new KVCollection[Namespace, Block.Hash, Block](BlockNamespace.Blocks), + new KVCollection[Namespace, Block.Hash, KVTree.NodeMeta[Hash]]( + BlockNamespace.BlockMetas + ), + new KVCollection[Namespace, Block.Hash, Set[Block.Hash]]( + BlockNamespace.BlockToChildren + ) + ) + + implicit val store = InMemoryKVStore[Task, Namespace]( + Ref.unsafe[Task, KVStoreState[Namespace]#Store](Map.empty) + ) + + Resource.liftF { + for { + _ <- store.runReadWrite { + ledgerStorage.put(initialLedger.hash, initialLedger) >> + blockStorage.put(initialBlock) + } + + ledgerTree <- Ref.of[Task, LedgerTree]( + LedgerTree.root(initialLedger, initialBlock.header) + ) + lastExec <- Ref.of[Task, Header](initialBlock.header) + chkpData <- Ref.of[Task, Option[CheckpointData]](None) + lastCert <- Ref.of[Task, Option[CheckpointCertificate]](None) + + service = new CheckpointingService[Task, Namespace]( + ledgerTree, + lastExec, + chkpData, + cc => lastCert.set(cc.some), + ledgerStorage, + blockStorage + ) + + } yield TestResources( + service, + ledgerStorage, + blockStorage, + store, + ledgerTree, + chkpData, + lastCert + ) + } + } + + // not used in the impl so a senseless value + val commitPath = NonEmptyList.one(initialBlock.header.parentHash) + + lazy val allTransactions = batch.flatMap(_.body.transactions) + lazy val finalLedger = + initialLedger.update(batch.flatMap(_.body.transactions)) + + lazy val expectedCheckpointCert = allTransactions.reverse.collectFirst { + case candidate: CheckpointCandidate => + //apparently identical checkpoints can be generated in different blocks + val blockPath = batch.drop( + batch.lastIndexWhere(_.body.transactions.contains(candidate)) + ) + val headerPath = NonEmptyList.fromListUnsafe(blockPath.map(_.header)) + + CheckpointCertificate.construct(blockPath.head, headerPath, commitQC) + }.flatten + } + + object TestFixture { + import ArbitraryInstances._ + + def gen(minChain: Int = 1): Gen[TestFixture] = { + for { + block <- arbitrary[Block] + ledger = Ledger.empty.update(block.body.transactions) + batch <- genBlockChain(block, ledger, min = minChain) + commitQC <- genCommitQC(batch.last) + seed <- Gen.posNum[Long] + } yield TestFixture(block, ledger, batch, commitQC, seed) + } + + def genBlockChain( + parent: Block, + initialLedger: Ledger, + min: Int = 1, + max: Int = 6 + ): Gen[List[Block]] = { + for { + n <- 
Gen.choose(min, max) + blocks <- Gen.listOfN(n, arbitrary[Block]) + } yield { + def link( + parent: Block, + prevLedger: Ledger, + chain: List[Block] + ): List[Block] = chain match { + case b :: bs => + val nextLedger = prevLedger.update(b.body.transactions) + val header = b.header.copy( + parentHash = parent.hash, + height = parent.header.height + 1, + postStateHash = nextLedger.hash + ) + val linked = Block.makeUnsafe(header, b.body) + linked :: link(linked, nextLedger, bs) + case Nil => + Nil + } + + link(parent, initialLedger, blocks) + } + } + + def genCommitQC( + block: Block + ): Gen[QuorumCertificate[CheckpointingAgreement]] = + arbitrary[QuorumCertificate[CheckpointingAgreement]].map { + _.copy[CheckpointingAgreement](phase = Commit, blockHash = block.hash) + } + } + + def run(fixture: TestFixture)(test: TestResources => Task[Prop]): Prop = { + import Scheduler.Implicits.global + + fixture.resources.use(test).runSyncUnsafe(timeout = 5.seconds) + } + + property("normal execution") = forAll(TestFixture.gen()) { fixture => + run(fixture) { res => + import fixture._ + import res._ + + val execution = batch + .map(checkpointingService.executeBlock(_, commitQC, commitPath)) + .sequence + + val ledgerStorageCheck = store.runReadOnly { + ledgerStorage.get(finalLedger.hash) + } + + for { + results <- execution + persistedLedger <- ledgerStorageCheck + ledgerTree <- ledgerTreeRef.get + lastCheckpoint <- lastCheckpointCertRef.get + checkpointData <- checkpointDataRef.get + } yield { + val ledgerTreeUpdated = + ledgerTree == LedgerTree.root(finalLedger, batch.last.header) + + val executionSuccessful = results.reverse match { + case true :: rest => !rest.exists(identity) + case _ => false + } + + all( + "execution successful" |: executionSuccessful, + "ledger persisted" |: persistedLedger.contains(finalLedger), + "ledgerTree updated" |: ledgerTreeUpdated, + "checkpoint constructed correctly" |: lastCheckpoint == expectedCheckpointCert, + "checkpoint data cleared" |: checkpointData.isEmpty + ) + } + } + } + + property("interrupted execution") = forAll(TestFixture.gen(minChain = 2)) { + fixture => + run(fixture) { res => + import fixture._ + import res._ + + // not executing the committed block + val execution = batch.init + .map(checkpointingService.executeBlock(_, commitQC, commitPath)) + .sequence + + for { + results <- execution + ledgerTree <- ledgerTreeRef.get + lastCheckpoint <- lastCheckpointCertRef.get + } yield { + val ledgerTreeUpdated = + batch.init.map(_.hash).forall(ledgerTree.contains) + val executionSuccessful = !results.exists(identity) + + all( + "executed correctly" |: executionSuccessful, + "ledgerTree updated" |: ledgerTreeUpdated, + "checkpoint constructed correctly" |: lastCheckpoint.isEmpty + ) + } + } + } + + property("failed execution - no parent") = + forAll(TestFixture.gen(minChain = 2)) { fixture => + run(fixture) { res => + import fixture._ + import res._ + + // parent block or its state is not saved so this must fail + val execution = batch.tail + .map(checkpointingService.executeBlock(_, commitQC, commitPath)) + .sequence + + execution.attempt.map { + case Left(ex: IllegalStateException) => + ex.getMessage.contains("Could not execute block") + case _ => false + } + } + } + + property("failed execution - height below last executed") = + forAll(TestFixture.gen(minChain = 2)) { fixture => + run(fixture) { res => + import fixture._ + import res._ + + val execution = batch + .map(checkpointingService.executeBlock(_, commitQC, commitPath)) + .sequence + + // repeated 
execution must fail because we're trying to execute a block of lower height + // than the last executed block + execution >> + execution.attempt.map { + case Left(ex: IllegalStateException) => + ex.getMessage.contains("Could not execute block") + case _ => false + } + } + } + + //TODO: Validate transactions PM-3131/3132 + // use a mocked interpreter client that always evaluates blocks as valid + property("parallel validation") = forAll(TestFixture.gen(minChain = 4)) { + fixture => + run(fixture) { res => + import fixture._ + import res._ + + // validation in random order so blocks need to be persisted first + val persistBlocks = store.runReadWrite { + batch.map(blockStorage.put).sequence + } + + def validation( + validating: Ref[Task, Boolean], + achievedPar: Ref[Task, Boolean] + ) = + Task.parSequence { + new Random(randomSeed) + .shuffle(batch) + .map(b => + for { + v <- validating.getAndSet(true) + _ <- achievedPar.update(_ || v) + r <- checkpointingService.validateBlock(b) + _ <- validating.set(false) + } yield r.getOrElse(false) + ) + } + + for { + _ <- persistBlocks + + // used to make sure that parallelism was achieved + validating <- Ref[Task].of(false) + achievedPar <- Ref[Task].of(false) + + result <- validation(validating, achievedPar) + par <- achievedPar.get + ledgerTree <- ledgerTreeRef.get + } yield { + val ledgerTreeUpdated = batch.forall(b => ledgerTree.contains(b.hash)) + + classify(par, "parallelism achieved") { + all( + "validation successful" |: result.forall(identity), + "ledgerTree updated" |: ledgerTreeUpdated + ) + } + } + } + } + + //TODO: Validate transactions PM-3131/3132 + // use a mocked interpreter client that always evaluates blocks as valid + property("execution parallel to validation") = forAllNoShrink { + for { + f <- TestFixture.gen(minChain = 4) + ext <- TestFixture.genBlockChain(f.batch.last, f.finalLedger) + } yield (f, f.batch ++ ext) + } { case (fixture, validationBatch) => + run(fixture) { res => + import fixture._ + import res._ + + // validation in random order so blocks need to be persisted first + val persistBlocks = store.runReadWrite { + validationBatch.map(blockStorage.put).sequence + } + + def validation( + validating: Ref[Task, Boolean], + executing: Ref[Task, Boolean], + achievedPar: Ref[Task, Boolean] + ) = { + new Random(randomSeed) + .shuffle(validationBatch) + .map(b => + for { + _ <- validating.set(true) + e <- executing.get + _ <- achievedPar.update(_ || e) + r <- checkpointingService.validateBlock(b) + _ <- validating.set(false) + } yield (r.getOrElse(false), b.header.height) + ) + .sequence + } + + def execution( + validating: Ref[Task, Boolean], + executing: Ref[Task, Boolean], + achievedPar: Ref[Task, Boolean] + ) = + batch + .map(b => + for { + _ <- executing.set(true) + v <- validating.get + _ <- achievedPar.update(_ || v) + r <- checkpointingService.executeBlock(b, commitQC, commitPath) + _ <- executing.set(false) + } yield r + ) + .sequence + + val ledgerStorageCheck = store.runReadOnly { + ledgerStorage.get(finalLedger.hash) + } + + for { + _ <- persistBlocks + + // used to make sure that parallelism was achieved + validating <- Ref[Task].of(false) + executing <- Ref[Task].of(false) + achievedPar <- Ref[Task].of(false) + + (validationRes, executionRes) <- Task.parZip2( + validation(validating, executing, achievedPar), + execution(validating, executing, achievedPar) + ) + + par <- achievedPar.get + persistedLedger <- ledgerStorageCheck + ledgerTree <- ledgerTreeRef.get + lastCheckpoint <- lastCheckpointCertRef.get + 
checkpointData <- checkpointDataRef.get + } yield { + val validationsAfterExec = validationRes.collect { + case (r, h) if h > batch.last.header.height => r + } + + val executionSuccessful = executionRes.reverse match { + case true :: rest => !rest.exists(identity) + case _ => false + } + + val ledgerTreeReset = batch.reverse match { + case committed :: rest => + ledgerTree + .get(committed.hash) + .contains(LedgerNode(finalLedger, committed.header)) && + rest.forall(b => !ledgerTree.contains(b.hash)) + + case _ => false + } + + val validationsSaved = + validationBatch.diff(batch).forall(b => ledgerTree.contains(b.hash)) + + classify(par, "parallelism achieved") { + all( + "validation successful" |: validationsAfterExec.forall(identity), + "execution successful" |: executionSuccessful, + "ledger persisted" |: persistedLedger.contains(finalLedger), + "ledgerTree reset" |: ledgerTreeReset, + "ledgerTree contains validations" |: validationsSaved, + "checkpoint constructed correctly" |: lastCheckpoint == expectedCheckpointCert, + "checkpoint data cleared" |: checkpointData.isEmpty + ) + } + } + } + } + +} diff --git a/metronome/checkpointing/service/test/src/io/iohk/metronome/checkpointing/service/storage/LedgerStorageProps.scala b/metronome/checkpointing/service/test/src/io/iohk/metronome/checkpointing/service/storage/LedgerStorageProps.scala new file mode 100644 index 00000000..88261834 --- /dev/null +++ b/metronome/checkpointing/service/test/src/io/iohk/metronome/checkpointing/service/storage/LedgerStorageProps.scala @@ -0,0 +1,81 @@ +package io.iohk.metronome.checkpointing.service.storage + +import cats.implicits._ +import io.iohk.metronome.core.Tagger +import io.iohk.metronome.checkpointing.models.Ledger +import io.iohk.metronome.checkpointing.models.ArbitraryInstances +import io.iohk.metronome.storage.{KVCollection, KVStoreState} +import org.scalacheck.{Properties, Gen, Arbitrary}, Arbitrary.arbitrary +import org.scalacheck.Prop.{forAll, all, propBoolean} +import scodec.Codec +import scodec.bits.BitVector +import org.scalacheck.Shrink +import scala.annotation.nowarn + +object LedgerStorageProps extends Properties("LedgerStorage") { + import ArbitraryInstances.arbLedger + + type Namespace = String + object Namespace { + val Ledgers = "ledgers" + val LedgerMeta = "ledger-meta" + } + + /** The in-memory KVStoreState doesn't invoke the codecs. 
*/ + implicit def neverUsedCodec[T] = + Codec[T]( + (_: T) => sys.error("Didn't expect to encode."), + (_: BitVector) => sys.error("Didn't expect to decode.") + ) + + object TestKVStore extends KVStoreState[Namespace] + + object HistorySize extends Tagger[Int] { + @nowarn + implicit val shrink: Shrink[HistorySize] = Shrink(s => Stream.empty) + implicit val arb: Arbitrary[HistorySize] = Arbitrary { + Gen.choose(1, 10).map(HistorySize(_)) + } + } + type HistorySize = HistorySize.Tagged + + property("buffer") = forAll( + for { + ledgers <- arbitrary[List[Ledger]] + maxSize <- arbitrary[HistorySize] + } yield (ledgers, maxSize) + ) { case (ledgers, maxSize) => + val ledgerStorage = new LedgerStorage[Namespace]( + new KVCollection[Namespace, Ledger.Hash, Ledger](Namespace.Ledgers), + Namespace.LedgerMeta, + maxHistorySize = maxSize + ) + + val store = + TestKVStore + .compile(ledgers.traverse(ledgerStorage.put)) + .runS(Map.empty) + .value + + def getByHash(ledgerHash: Ledger.Hash) = + TestKVStore.compile(ledgerStorage.get(ledgerHash)).run(store) + + val ledgerMap = store.get(Namespace.Ledgers).getOrElse(Map.empty[Any, Any]) + + val (current, old) = { + val (lastN, prev) = ledgers.reverse.splitAt(maxSize) + // There can be duplicates, re-insertions. + (lastN, prev.filterNot(lastN.contains)) + } + + all( + "max-history" |: ledgerMap.values.size <= maxSize, + "contains current" |: current.forall { ledger => + getByHash(ledger.hash).contains(ledger) + }, + "not contain old" |: old.forall { ledger => + getByHash(ledger.hash).isEmpty + } + ) + } +} diff --git a/metronome/config/src/io/iohk/metronome/config/ConfigDecoders.scala b/metronome/config/src/io/iohk/metronome/config/ConfigDecoders.scala new file mode 100644 index 00000000..7fb4dd26 --- /dev/null +++ b/metronome/config/src/io/iohk/metronome/config/ConfigDecoders.scala @@ -0,0 +1,82 @@ +package io.iohk.metronome.config + +import io.circe._ +import com.typesafe.config.{ConfigFactory, Config} +import scala.util.Try +import scala.concurrent.duration._ + +object ConfigDecoders { + + /** Parse a string into a TypeSafe config an use one of the accessor methods. */ + private def tryParse[T](value: String, f: (Config, String) => T): Try[T] = + Try { + val key = "dummy" + val conf = ConfigFactory.parseString(s"$key = $value") + f(conf, key) + } + + /** Parse HOCON byte sizes like "128M". */ + val bytesDecoder: Decoder[Long] = + Decoder[String].emapTry { + tryParse(_, _ getBytes _) + } + + /** Parse HOCON durations like "5m". */ + val durationDecoder: Decoder[FiniteDuration] = + Decoder[String].emapTry { + tryParse(_, _.getDuration(_).toMillis.millis) + } + + /** Parse an object where a discriminant tells us which other key value + * to deserialise into the target type. + * + * For example take the following config: + * + * ``` + * virus { + * variant = alpha + * alpha { + * r = 1.1 + * } + * delta { + * r = 1.4 + * } + * } + * ``` + * + * It should deserialize into a class that matches a sub-key: + * ``` + * case class Virus(r: Double) + * object Virus { + * implicit val decoder: Decoder[Virus] = + * ConfigDecoders.strategyDecoder[Virus]("variant", deriveDecoder) + * } + * ``` + * + * The decoder will deserialise all the other keys as well to make sure + * that all of them are valid, in case the selection changes. + */ + def strategyDecoder[T]( + discriminant: String, + decoder: Decoder[T] + ): Decoder[T] = { + // This parser is applied after the fields have been transformed to camelCase. 
+ import ConfigParser.toCamelCase + // Not passing the decoder implicitly so the compiler doesn't pass + // the one we are constructing here. + implicit val inner: Decoder[T] = decoder + + Decoder.instance[T] { (c: HCursor) => + for { + obj <- c.value.as[JsonObject] + ccd = toCamelCase(discriminant) + selected <- c.downField(ccd).as[String].map(toCamelCase) + value <- c.downField(selected).as[T] + // Making sure that everything else is valid. We could pick the value from the result, + // but it's more difficult to provide the right `DecodingFailure` with a list of operations + // if the selected key is not present in the map. + _ <- Json.fromJsonObject(obj.remove(ccd)).as[Map[String, T]] + } yield value + } + } +} diff --git a/metronome/config/src/io/iohk/metronome/config/ConfigParser.scala b/metronome/config/src/io/iohk/metronome/config/ConfigParser.scala new file mode 100644 index 00000000..5df4e42d --- /dev/null +++ b/metronome/config/src/io/iohk/metronome/config/ConfigParser.scala @@ -0,0 +1,185 @@ +package io.iohk.metronome.config + +import cats.implicits._ +import com.typesafe.config.{ConfigObject, ConfigRenderOptions} +import io.circe.{Json, JsonObject, ParsingFailure, Decoder, DecodingFailure} +import io.circe.parser.{parse => parseJson} + +object ConfigParser { + protected[config] type ParsingResult = Either[ParsingFailure, Json] + + type Result[T] = Either[Either[ParsingFailure, DecodingFailure], T] + + /** Parse configuration into a type using a JSON decoder, thus allowing + * validations to be applied to all configuraton values up front, rather + * than fail lazily when something is accessed or instantiated from + * the config factory. + * + * Accept overrides from the environment in PREFIX_PATH_TO_FIELD format. + */ + def parse[T: Decoder]( + conf: ConfigObject, + prefix: String = "", + env: Map[String, String] = sys.env + ): Result[T] = { + // Render the whole config to JSON. Everything needs a default value, + // but it can be `null` and be replaced from the environment. + val orig = toJson(conf) + // Transform fields which use dash for segmenting into camelCase. + val withCamel = withCamelCase(orig) + // Apply overrides from env vars. + val withEnv = withEnvVarOverrides(withCamel, prefix, env) + // Map to the domain config model. + withEnv match { + case Left(error) => Left(Left(error)) + case Right(json) => + Decoder[T].decodeJson(json) match { + case Left(error) => Left(Right(error)) + case Right(value) => Right(value) + } + } + } + + /** Render a TypeSafe Config section into JSON. */ + protected[config] def toJson(conf: ConfigObject): Json = { + val raw = conf.render(ConfigRenderOptions.concise) + parseJson(raw) match { + case Left(error: ParsingFailure) => + // This shouldn't happen with a well formed config file, + // which would have already failed during parsing or projecting + // to a `ConfigObject` passed to this method. + throw new IllegalArgumentException(error.message, error.underlying) + + case Right(json) => + json + } + } + + /** Transform a key in the HOCON config file to camelCase. */ + protected[config] def toCamelCase(key: String): String = { + def loop(cs: List[Char]): List[Char] = + cs match { + case ('_' | '-') :: cs => + cs match { + case c :: cs => c.toUpper :: loop(cs) + case Nil => Nil + } + case c :: cs => c :: loop(cs) + case Nil => Nil + } + + loop(key.toList).mkString + } + + /** Turn `camelCaseKey` into `SNAKE_CASE_KEY`, + * which is what it would look like as an env var. 
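For instance, with the prefix handling in `withEnvVarOverrides` below, a camelCase config path maps onto an environment variable name like this (the prefix and path are made up; since `toSnakeCase` is package-private, the sketch lives in the same package):

```scala
package io.iohk.metronome.config

object EnvVarNameExample {
  // network.maxPacketSize with prefix "METRONOME" is looked up in the environment
  // as METRONOME_NETWORK_MAX_PACKET_SIZE.
  val path: List[String] = List("network", "maxPacketSize")

  val envKey: String =
    ("METRONOME" :: path.map(ConfigParser.toSnakeCase)).mkString("_")
  // envKey == "METRONOME_NETWORK_MAX_PACKET_SIZE"
}
```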
+ */ + protected[config] def toSnakeCase(camelCase: String): String = { + def loop(cs: List[Char]): List[Char] = + cs match { + case a :: b :: cs if a.isLower && b.isUpper => + a.toUpper :: '_' :: b :: loop(cs) + case '-' :: cs => + '_' :: loop(cs) + case a :: cs => + a.toUpper :: loop(cs) + case Nil => + Nil + } + + loop(camelCase.toList).mkString + } + + /** Transform all keys into camelCase form, + * so they can be matched to case class fields. + */ + protected[config] def withCamelCase(json: Json): Json = { + json + .mapArray { arr => + arr.map(withCamelCase) + } + .mapObject { obj => + JsonObject(obj.toIterable.map { case (key, value) => + toCamelCase(key) -> withCamelCase(value) + }.toList: _*) + } + } + + /** Apply overrides from the environment to a JSON structure. + * + * Only considers env var keys that start with prefix and are + * in a PREFIX_SNAKE_CASE format. + * + * The operation can fail if a value in the environment is + * incompatible with the default in the config files. + * + * Default values in the config file are necessary, because + * the environment variable name in itself doesn't uniquely + * define a data structure (a single underscore matches both + * a '.' or a '-' in the path). + */ + protected[config] def withEnvVarOverrides( + json: Json, + prefix: String, + env: Map[String, String] = sys.env + ): ParsingResult = { + def extend(path: String, key: String) = + if (path.isEmpty) key else s"${path}_${key}" + + def loop(json: Json, path: String): ParsingResult = { + + def tryParse( + default: => Json, + validate: Json => Boolean + ): ParsingResult = + env + .get(path) + .map { value => + val maybeJson = parseJson(value) orElse parseJson(s""""$value"""") + + maybeJson.flatMap { json => + if (validate(json)) { + Right(json) + } else { + val msg = s"Invalid value for $path: $value" + Left(ParsingFailure(value, new IllegalArgumentException(msg))) + } + } + } + .getOrElse(Right(default)) + + json + .fold[ParsingResult]( + jsonNull = tryParse(Json.Null, _ => true), + jsonBoolean = x => tryParse(Json.fromBoolean(x), _.isBoolean), + jsonNumber = x => tryParse(Json.fromJsonNumber(x), _.isNumber), + jsonString = x => tryParse(Json.fromString(x), _.isString), + jsonArray = { arr => + arr.zipWithIndex + .map { case (value, idx) => + loop(value, extend(path, idx.toString)) + } + .sequence + .map { values => + Json.arr(values: _*) + } + }, + jsonObject = { obj => + obj.toIterable + .map { case (key, value) => + val snakeKey = toSnakeCase(key) + loop(value, extend(path, snakeKey)).map(key ->) + } + .toList + .sequence + .map { values => + Json.obj(values: _*) + } + } + ) + } + + loop(json, prefix) + } + +} diff --git a/metronome/config/test/resources/complex.conf b/metronome/config/test/resources/complex.conf new file mode 100644 index 00000000..34fe76f9 --- /dev/null +++ b/metronome/config/test/resources/complex.conf @@ -0,0 +1,27 @@ +metronome { + metrics { + enabled = false + } + network { + bootstrap = [ + "localhost:40001" + ], + timeout = 5s + max-packet-size = 512kB + client-id = null + } + blockchain { + consensus = "research-and-development" + default { + max-block-size = 1MB + view-timeout = 15s + } + research-and-development = ${metronome.blockchain.default} { + max-block-size = 10MB + } + main = ${metronome.blockchain.default} { + view-timeout = 5s + } + } + chain-id = test-chain +} diff --git a/metronome/config/test/resources/override.conf b/metronome/config/test/resources/override.conf new file mode 100644 index 00000000..ef9aeb70 --- /dev/null +++ 
b/metronome/config/test/resources/override.conf @@ -0,0 +1,22 @@ +override { + metrics { + enabled = false + } + network { + bootstrap = [ + "localhost:40001", + "localhost:40002" + ] + } + optional = null + numeric = 123 + textual = Hello World + boolean = true +} + +# Other setting that shouldn't be affected. +other { + metrics { + enabled = false + } +} diff --git a/metronome/config/test/resources/simple.conf b/metronome/config/test/resources/simple.conf new file mode 100644 index 00000000..d285e7b8 --- /dev/null +++ b/metronome/config/test/resources/simple.conf @@ -0,0 +1,11 @@ +# The root we are going to start from. +simple { + # Property name with a dash. + nested-structure { + foo = 10 + # Property name with an underscore. + bar_baz { + spam = eggs + } + } +} diff --git a/metronome/config/test/src/io/iohk/metronome/config/ConfigParserSpec.scala b/metronome/config/test/src/io/iohk/metronome/config/ConfigParserSpec.scala new file mode 100644 index 00000000..a1cf8fb0 --- /dev/null +++ b/metronome/config/test/src/io/iohk/metronome/config/ConfigParserSpec.scala @@ -0,0 +1,172 @@ +package io.iohk.metronome.config + +import com.typesafe.config.ConfigFactory +import io.circe.Decoder +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks +import org.scalatest.Inside +import scala.concurrent.duration._ + +class ConfigParserSpec + extends AnyFlatSpec + with Matchers + with TableDrivenPropertyChecks + with Inside { + + "toJson" should "parse simple.conf to JSON" in { + val conf = ConfigFactory.load("simple.conf") + val json = ConfigParser.toJson(conf.getConfig("simple").root()) + json.noSpaces shouldBe """{"nested-structure":{"bar_baz":{"spam":"eggs"},"foo":10}}""" + } + + "toCamelCase" should "turn keys into camelCase" in { + val examples = Table( + ("input", "expected"), + ("nested-structure", "nestedStructure"), + ("nested_structure", "nestedStructure"), + ("multiple-dashes_and_underscores", "multipleDashesAndUnderscores"), + ("multiple-dashes_and_underscores", "multipleDashesAndUnderscores"), + ("camelCaseKey", "camelCaseKey") + ) + forAll(examples) { case (input, expected) => + ConfigParser.toCamelCase(input) shouldBe expected + } + } + + "toSnakeCase" should "turn camelCase keys into SNAKE_CASE" in { + val examples = Table( + ("input", "expected"), + ("nestedStructure", "NESTED_STRUCTURE"), + ("nested_structure", "NESTED_STRUCTURE"), + ("nested-structure", "NESTED_STRUCTURE") + ) + forAll(examples) { case (input, expected) => + ConfigParser.toSnakeCase(input) shouldBe expected + } + } + + "withCamelCase" should "turn all keys in a JSON object into camelCase" in { + val conf = ConfigFactory.load("simple.conf") + val orig = ConfigParser.toJson(conf.root()) + val json = (ConfigParser.withCamelCase(orig) \\ "simple").head + json.noSpaces shouldBe """{"nestedStructure":{"barBaz":{"spam":"eggs"},"foo":10}}""" + } + + "withEnvVarOverrides" should "overwrite keys from the environment" in { + val conf = ConfigFactory.load("override.conf") + val orig = ConfigParser.toJson(conf.getConfig("override").root()) + val json = ConfigParser.withCamelCase(orig) + + val env = Map( + "TEST_METRICS_ENABLED" -> "true", + "TEST_NETWORK_BOOTSTRAP_0" -> "localhost:50000", + "TEST_OPTIONAL" -> "test", + "TEST_NUMERIC" -> "456", + "TEST_TEXTUAL" -> "Terra Nostra", + "TEST_BOOLEAN" -> "false" + ) + + val result = ConfigParser.withEnvVarOverrides(json, "TEST", env) + + inside(result) { case Right(json) => + json.noSpaces shouldBe 
"""{"boolean":false,"metrics":{"enabled":true},"network":{"bootstrap":["localhost:50000","localhost:40002"]},"numeric":456,"optional":"test","textual":"Terra Nostra"}""" + } + } + + it should "validate that data types are not altered" in { + val conf = ConfigFactory.load("override.conf") + val orig = ConfigParser.toJson(conf.root()) + val json = ConfigParser.withCamelCase(orig) + + val examples = Table( + ("path", "invalid"), + ("OVERRIDE_NUMERIC", "NaN"), + ("OVERRIDE_TEXTUAL", "123"), + ("OVERRIDE_BOOLEAN", "no") + ) + forAll(examples) { case (path, invalid) => + ConfigParser + .withEnvVarOverrides(json, "", Map(path -> invalid)) + .isLeft shouldBe true + } + } + + "parse" should "decode into a configuration model" in { + import ConfigParserSpec.TestConfig + + val config = ConfigParser.parse[TestConfig]( + ConfigFactory.load("complex.conf").getConfig("metronome").root(), + prefix = "TEST", + env = Map("TEST_METRICS_ENABLED" -> "true") + ) + + inside(config) { case Right(config) => + config shouldBe TestConfig( + TestConfig.Metrics(enabled = true), + TestConfig.Network( + bootstrap = List("localhost:40001"), + timeout = 5.seconds, + maxPacketSize = TestConfig.Size(512000), + clientId = None + ), + TestConfig + .Blockchain( + maxBlockSize = TestConfig.Size(10000000), + viewTimeout = 15.seconds + ), + chainId = Some("test-chain") + ) + } + } +} + +object ConfigParserSpec { + import io.circe._, io.circe.generic.semiauto._ + + case class TestConfig( + metrics: TestConfig.Metrics, + network: TestConfig.Network, + blockchain: TestConfig.Blockchain, + chainId: Option[String] + ) + object TestConfig { + implicit val durationDecoder: Decoder[FiniteDuration] = + ConfigDecoders.durationDecoder + + case class Metrics(enabled: Boolean) + object Metrics { + implicit val decoder: Decoder[Metrics] = + deriveDecoder + } + + case class Network( + bootstrap: List[String], + timeout: FiniteDuration, + maxPacketSize: Size, + clientId: Option[String] + ) + object Network { + implicit val decoder: Decoder[Network] = + deriveDecoder + } + + case class Size(bytes: Long) + object Size { + implicit val decoder: Decoder[Size] = + ConfigDecoders.bytesDecoder.map(Size(_)) + } + + case class Blockchain( + maxBlockSize: Size, + viewTimeout: FiniteDuration + ) + object Blockchain { + implicit val decoder: Decoder[Blockchain] = + ConfigDecoders.strategyDecoder[Blockchain]("consensus", deriveDecoder) + } + + implicit val decoder: Decoder[TestConfig] = + deriveDecoder + } +} diff --git a/metronome/core/src/io/iohk/metronome/core/Pipe.scala b/metronome/core/src/io/iohk/metronome/core/Pipe.scala new file mode 100644 index 00000000..53eed942 --- /dev/null +++ b/metronome/core/src/io/iohk/metronome/core/Pipe.scala @@ -0,0 +1,50 @@ +package io.iohk.metronome.core + +import cats.implicits._ +import cats.effect.{Concurrent, ContextShift, Sync} +import monix.tail.Iterant +import monix.catnap.ConcurrentQueue + +/** A `Pipe` is a connection between two components where + * messages of type `L` are going from left to right and + * messages of type `R` are going from right to left. + */ +trait Pipe[F[_], L, R] { + type Left = Pipe.Side[F, L, R] + type Right = Pipe.Side[F, R, L] + + def left: Left + def right: Right +} +object Pipe { + + /** One side of a `Pipe` with + * messages of type `I` going in and + * messages of type `O` coming out. 
+ */ + trait Side[F[_], I, O] { + def send(in: I): F[Unit] + def receive: Iterant[F, O] + } + object Side { + def apply[F[_]: Sync, I, O]( + iq: ConcurrentQueue[F, I], + oq: ConcurrentQueue[F, O] + ): Side[F, I, O] = new Side[F, I, O] { + override def send(in: I): F[Unit] = + iq.offer(in) + + override def receive: Iterant[F, O] = + Iterant.repeatEvalF(oq.poll) + } + } + + def apply[F[_]: Concurrent: ContextShift, L, R]: F[Pipe[F, L, R]] = + for { + lq <- ConcurrentQueue.unbounded[F, L](None) + rq <- ConcurrentQueue.unbounded[F, R](None) + } yield new Pipe[F, L, R] { + override val left = Side[F, L, R](lq, rq) + override val right = Side[F, R, L](rq, lq) + } +} diff --git a/metronome/core/src/io/iohk/metronome/core/Tagger.scala b/metronome/core/src/io/iohk/metronome/core/Tagger.scala new file mode 100644 index 00000000..da4a1f68 --- /dev/null +++ b/metronome/core/src/io/iohk/metronome/core/Tagger.scala @@ -0,0 +1,37 @@ +package io.iohk.metronome.core + +import shapeless.tag, tag.@@ + +/** Helper class to make it easier to tag raw types such as BitVector + * to specializations so that the compiler can help make sure we are + * passign the right values to methods. + * + * ``` + * object MyType extends Tagger[ByteVector] + * type MyType = MyType.Tagged + * + * val myThing: MyType = MyType(ByteVector.empty) + * ``` + */ +trait Tagger[U] { + trait Tag + type Tagged = U @@ Tag + + def apply(underlying: U): Tagged = + tag[Tag][U](underlying) +} + +/** Helper class to tag not a specific raw type, but to apply a common tag to any type. + * + * ``` + * object Validated extends GenericTagger + * type Validated[U] = Validated.Tagged[U] + * ``` + */ +trait GenericTagger { + trait Tag + type Tagged[U] = U @@ Tag + + def apply[U](underlying: U): Tagged[U] = + tag[Tag][U](underlying) +} diff --git a/metronome/core/src/io/iohk/metronome/core/Validated.scala b/metronome/core/src/io/iohk/metronome/core/Validated.scala new file mode 100644 index 00000000..96c902f8 --- /dev/null +++ b/metronome/core/src/io/iohk/metronome/core/Validated.scala @@ -0,0 +1,14 @@ +package io.iohk.metronome.core + +/** Can be used to tag any particular type as validated, for example: + * + * ``` + * def validateBlock(block: Block): Either[Error, Validated[Block]] + * def storeBlock(block: Validated[Block]) + * ``` + * + * It's a bit more lightweight than opting into the `ValidatedNel` from `cats`, + * mostly just serves as control that the right methods have been called in a + * pipeline. + */ +object Validated extends GenericTagger diff --git a/metronome/core/src/io/iohk/metronome/core/fibers/DeferredTask.scala b/metronome/core/src/io/iohk/metronome/core/fibers/DeferredTask.scala new file mode 100644 index 00000000..72cd0e99 --- /dev/null +++ b/metronome/core/src/io/iohk/metronome/core/fibers/DeferredTask.scala @@ -0,0 +1,41 @@ +package io.iohk.metronome.core.fibers + +import cats.implicits._ +import cats.effect.Sync +import cats.effect.concurrent.Deferred +import cats.effect.Concurrent +import scala.util.control.NoStackTrace + +/** A task that can be executed on a fiber pool, or canceled if the pool is shut down.. */ +protected[fibers] class DeferredTask[F[_]: Sync, A]( + deferred: Deferred[F, Either[Throwable, A]], + task: F[A] +) { + import DeferredTask.CanceledException + + /** Execute the task and set the success/failure result on the deferred. */ + def execute: F[Unit] = + task.attempt.flatMap(deferred.complete) + + /** Get the result of the execution, raising an error if it failed. 
*/ + def join: F[A] = + deferred.get.rethrow + + /** Signal to the submitter that this task is canceled. */ + def cancel: F[Unit] = + deferred + .complete(Left(new CanceledException)) + .attempt + .void +} + +object DeferredTask { + class CanceledException + extends RuntimeException("This task has been canceled.") + with NoStackTrace + + def apply[F[_]: Concurrent, A](task: F[A]): F[DeferredTask[F, A]] = + Deferred[F, Either[Throwable, A]].map { d => + new DeferredTask[F, A](d, task) + } +} diff --git a/metronome/core/src/io/iohk/metronome/core/fibers/FiberMap.scala b/metronome/core/src/io/iohk/metronome/core/fibers/FiberMap.scala new file mode 100644 index 00000000..9aedadcc --- /dev/null +++ b/metronome/core/src/io/iohk/metronome/core/fibers/FiberMap.scala @@ -0,0 +1,170 @@ +package io.iohk.metronome.core.fibers + +import cats.implicits._ +import cats.effect.{Sync, Concurrent, ContextShift, Fiber, Resource} +import cats.effect.concurrent.{Ref, Semaphore} +import monix.catnap.ConcurrentQueue +import monix.execution.BufferCapacity +import monix.execution.ChannelType +import scala.util.control.NoStackTrace + +/** Execute tasks on a separate fiber per source key, + * facilitating separate rate limiting and fair concurrency. + * + * Each fiber executes tasks one by one. + */ +class FiberMap[F[_]: Concurrent: ContextShift, K]( + isShutdownRef: Ref[F, Boolean], + actorMapRef: Ref[F, Map[K, FiberMap.Actor[F]]], + semaphore: Semaphore[F], + capacity: BufferCapacity +) { + + /** Submit a task to be processed in the background. + * + * Create a new fiber if the given key hasn't got one yet. + * + * The result can be waited upon or discarded, the processing + * will happen in the background regardless. + */ + def submit[A](key: K)(task: F[A]): F[F[A]] = { + isShutdownRef.get.flatMap { + case true => + Sync[F].raiseError(new FiberMap.ShutdownException) + + case false => + actorMapRef.get.map(_.get(key)).flatMap { + case Some(actor) => + actor.submit(task) + case None => + semaphore.withPermit { + actorMapRef.get.map(_.get(key)).flatMap { + case Some(actor) => + actor.submit(task) + case None => + for { + actor <- FiberMap.Actor[F](capacity) + _ <- actorMapRef.update( + _.updated(key, actor) + ) + join <- actor.submit(task) + } yield join + } + } + } + } + } + + /** Cancel all enqueued tasks for a key. */ + def cancelQueue(key: K): F[Unit] = + actorMapRef.get.map(_.get(key)).flatMap { + case Some(actor) => actor.cancelQueue + case None => ().pure[F] + } + + /** Cancel all existing background processors. */ + private def shutdown: F[Unit] = { + semaphore.withPermit { + for { + _ <- isShutdownRef.set(true) + actorMap <- actorMapRef.get + _ <- actorMap.values.toList.traverse(_.shutdown) + } yield () + } + } +} + +object FiberMap { + + /** The queue of a key is at capacity and didn't accept the task. */ + class QueueFullException + extends RuntimeException("The fiber task queue is full.") + with NoStackTrace + + class ShutdownException + extends IllegalStateException("The pool is already shut down.") + + private class Actor[F[_]: Concurrent]( + queue: ConcurrentQueue[F, DeferredTask[F, _]], + runningRef: Ref[F, Option[DeferredTask[F, _]]], + fiber: Fiber[F, Unit] + ) { + + private val reject = Sync[F].raiseError[Unit](new QueueFullException) + + /** Submit a task to the queue, to be processed by the fiber. + * + * If the queue is full, a `QueueFullException` is raised so the submitting + * process knows that this key is producing too much data. 
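A minimal usage sketch of the `FiberMap` above, mirroring what the test suite further down does; the `String` peer key and `handleMessage` are illustrative:

```scala
import io.iohk.metronome.core.fibers.FiberMap
import monix.eval.Task
import monix.execution.Scheduler.Implicits.global

def handleMessage(msg: String): Task[Unit] = Task(println(msg))

val program: Task[Unit] =
  FiberMap[Task, String]().use { fiberMap =>
    for {
      // The outer Task completes once the message is enqueued for "peer-1";
      // the returned handle completes when it has actually been processed.
      handle <- fiberMap.submit("peer-1")(handleMessage("hello"))
      _      <- handle
    } yield ()
  }
```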
+ */ + def submit[A](task: F[A]): F[F[A]] = + for { + wrapper <- DeferredTask[F, A](task) + enqueued <- queue.tryOffer(wrapper) + _ <- reject.whenA(!enqueued) + } yield wrapper.join + + /** Cancel all enqueued tasks. */ + def cancelQueue: F[Unit] = + for { + tasks <- queue.drain(0, Int.MaxValue) + _ <- tasks.toList.traverse(_.cancel) + } yield () + + /** Cancel the processing and signal to all enqueued tasks that they will not be executed. */ + def shutdown: F[Unit] = + for { + _ <- fiber.cancel + maybeRunning <- runningRef.get + _ <- maybeRunning.fold(().pure[F])(_.cancel) + tasks <- cancelQueue + } yield () + } + private object Actor { + + /** Execute all tasks in the queue. */ + def process[F[_]: Sync]( + queue: ConcurrentQueue[F, DeferredTask[F, _]], + runningRef: Ref[F, Option[DeferredTask[F, _]]] + ): F[Unit] = + queue.poll.flatMap { task => + for { + _ <- runningRef.set(task.some) + _ <- task.execute + _ <- runningRef.set(none) + } yield () + } >> process(queue, runningRef) + + /** Create an actor and start executing tasks in the background. */ + def apply[F[_]: Concurrent: ContextShift]( + capacity: BufferCapacity + ): F[Actor[F]] = + for { + queue <- ConcurrentQueue + .withConfig[F, DeferredTask[F, _]](capacity, ChannelType.MPSC) + runningRef <- Ref[F].of(none[DeferredTask[F, _]]) + fiber <- Concurrent[F].start(process(queue, runningRef)) + } yield new Actor[F](queue, runningRef, fiber) + } + + /** Create an empty fiber pool. Cancel all fibers when it's released. */ + def apply[F[_]: Concurrent: ContextShift, K]( + capacity: BufferCapacity = BufferCapacity.Unbounded(None) + ): Resource[F, FiberMap[F, K]] = + Resource.make(build[F, K](capacity))(_.shutdown) + + private def build[F[_]: Concurrent: ContextShift, K]( + capacity: BufferCapacity + ): F[FiberMap[F, K]] = + for { + isShutdownRef <- Ref[F].of(false) + actorMapRef <- Ref[F].of(Map.empty[K, Actor[F]]) + semaphore <- Semaphore[F](1) + pool = new FiberMap[F, K]( + isShutdownRef, + actorMapRef, + semaphore, + capacity + ) + } yield pool +} diff --git a/metronome/core/src/io/iohk/metronome/core/fibers/FiberSet.scala b/metronome/core/src/io/iohk/metronome/core/fibers/FiberSet.scala new file mode 100644 index 00000000..2598156c --- /dev/null +++ b/metronome/core/src/io/iohk/metronome/core/fibers/FiberSet.scala @@ -0,0 +1,69 @@ +package io.iohk.metronome.core.fibers + +import cats.implicits._ +import cats.effect.{Concurrent, Fiber, Resource} +import cats.effect.concurrent.{Ref, Deferred} + +/** Execute tasks in the background, canceling all fibers if the resource is released. + * + * Facilitates structured concurrency where the release of the component that submitted + * these fibers causes the cancelation of all of its scheduled tasks. + */ +class FiberSet[F[_]: Concurrent]( + isShutdownRef: Ref[F, Boolean], + fibersRef: Ref[F, Set[Fiber[F, Unit]]], + tasksRef: Ref[F, Set[DeferredTask[F, _]]] +) { + private def raiseIfShutdown: F[Unit] = + isShutdownRef.get.ifM( + Concurrent[F].raiseError(new FiberSet.ShutdownException), + ().pure[F] + ) + + def submit[A](task: F[A]): F[F[A]] = for { + _ <- raiseIfShutdown + deferredFiber <- Deferred[F, Fiber[F, Unit]] + + // Run the task, then remove the fiber from the tracker. + background: F[A] = for { + exec <- task.attempt + fiber <- deferredFiber.get + _ <- fibersRef.update(_ - fiber) + result <- Concurrent[F].delay(exec).rethrow + } yield result + + wrapper <- DeferredTask[F, A](background) + _ <- tasksRef.update(_ + wrapper) + + // Start running in the background. 
Only now do we know the identity of the fiber. + fiber <- Concurrent[F].start(wrapper.execute) + + // Add the fiber to the collectin first, so that if the effect is + // already finished, it gets to remove it and we're not leaking memory. + _ <- fibersRef.update(_ + fiber) + _ <- deferredFiber.complete(fiber) + + } yield wrapper.join + + def shutdown: F[Unit] = for { + _ <- isShutdownRef.set(true) + fibers <- fibersRef.get + _ <- fibers.toList.traverse(_.cancel) + tasks <- tasksRef.get + _ <- tasks.toList.traverse(_.cancel) + } yield () +} + +object FiberSet { + class ShutdownException + extends IllegalStateException("The pool is already shut down.") + + def apply[F[_]: Concurrent]: Resource[F, FiberSet[F]] = + Resource.make[F, FiberSet[F]] { + for { + isShutdownRef <- Ref[F].of(false) + fibersRef <- Ref[F].of(Set.empty[Fiber[F, Unit]]) + tasksRef <- Ref[F].of(Set.empty[DeferredTask[F, _]]) + } yield new FiberSet[F](isShutdownRef, fibersRef, tasksRef) + }(_.shutdown) +} diff --git a/metronome/core/src/io/iohk/metronome/core/messages/RPCMessage.scala b/metronome/core/src/io/iohk/metronome/core/messages/RPCMessage.scala new file mode 100644 index 00000000..3e7d9ddf --- /dev/null +++ b/metronome/core/src/io/iohk/metronome/core/messages/RPCMessage.scala @@ -0,0 +1,48 @@ +package io.iohk.metronome.core.messages + +import cats.effect.Sync +import java.util.UUID + +/** Messages that go in request/response pairs. */ +trait RPCMessage { + + /** Unique identifier for request, which is expected to be + * included in the response message that comes back. + */ + def requestId: UUID +} + +abstract class RPCMessageCompanion { + type RequestId = UUID + + object RequestId { + def apply(): RequestId = + UUID.randomUUID() + + def apply[F[_]: Sync]: F[RequestId] = + Sync[F].delay(apply()) + } + + trait Request extends RPCMessage + trait Response extends RPCMessage + + /** Establish a relationship between a request and a response + * type so the compiler can infer the return value of methods + * based on the request parameter, or validate that two generic + * parameters belong with each other. + */ + def pair[A <: Request, B <: Response]: RPCPair.Aux[A, B] = + new RPCPair[A] { type Response = B } +} + +/** A request can be associated with at most one response type. + * On the other hand a response type can serve multiple requests. + */ +trait RPCPair[Request] { + type Response +} +object RPCPair { + type Aux[A, B] = RPCPair[A] { + type Response = B + } +} diff --git a/metronome/core/src/io/iohk/metronome/core/messages/RPCTracker.scala b/metronome/core/src/io/iohk/metronome/core/messages/RPCTracker.scala new file mode 100644 index 00000000..45200889 --- /dev/null +++ b/metronome/core/src/io/iohk/metronome/core/messages/RPCTracker.scala @@ -0,0 +1,117 @@ +package io.iohk.metronome.core.messages + +import cats.implicits._ +import cats.effect.{Concurrent, Timer, Sync} +import cats.effect.concurrent.{Ref, Deferred} +import java.util.UUID +import scala.concurrent.duration.FiniteDuration +import scala.reflect.ClassTag + +/** `RPCTracker` can be used to register outgoing requests and later + * match them up with incoming responses, thus it facilitates turning + * the two independent messages into a `Kleisli[F, Request, Option[Response]]`, + * by a component that has access to the network, where a `None` result means + * the operation timed out before a response was received. + * + * The workflow is: + * 0. Receive some request parameters in a method. + * 1. Create a request ID. + * 2. Create a request with the ID. 
+ * 3. Register the request with the tracker, hold on to the handle. + * 4. Send the request over the network. + * 5. Wait on the handle, eventually returning the optional result to the caller. + * 6. Pass every response received from the network to the tracker (on the network handler fiber). + */ +class RPCTracker[F[_]: Timer: Concurrent, M]( + deferredMapRef: Ref[F, Map[UUID, RPCTracker.Entry[F, _]]], + defaultTimeout: FiniteDuration +) { + import RPCTracker.Entry + + def register[ + Req <: RPCMessageCompanion#Request, + Res <: RPCMessageCompanion#Response + ]( + request: Req, + timeout: FiniteDuration = defaultTimeout + )(implicit + ev1: Req <:< M, + ev2: RPCPair.Aux[Req, Res], + // Used by `RPCTracker.Entry.complete` to make sure only the + // expected response type can complete a request. + ct: ClassTag[Res] + ): F[F[Option[Res]]] = { + val requestId = request.requestId + for { + d <- Deferred[F, Option[Res]] + e = RPCTracker.Entry(d) + _ <- deferredMapRef.update(_ + (requestId -> e)) + _ <- Concurrent[F].start { + Timer[F].sleep(timeout) >> completeWithTimeout(requestId) + } + } yield d.get + } + + /** Try to complete an outstanding request with a response. + * + * Returns `true` if the response was expected, `false` if + * it wasn't, or already timed out. An error is returned + * if the response was expected but the there was a type + * mismatch. + */ + def complete[Res <: RPCMessageCompanion#Response]( + response: Res + )(implicit ev: Res <:< M): F[Either[Throwable, Boolean]] = { + remove(response.requestId).flatMap { + case None => false.asRight[Throwable].pure[F] + case Some(e) => e.complete(response) + } + } + + private def completeWithTimeout(requestId: UUID): F[Unit] = + remove(requestId).flatMap { + case None => ().pure[F] + case Some(e) => e.timeout + } + + private def remove(requestId: UUID): F[Option[Entry[F, _]]] = + deferredMapRef.modify { dm => + (dm - requestId, dm.get(requestId)) + } +} + +object RPCTracker { + case class Entry[F[_]: Sync, Res]( + deferred: Deferred[F, Option[Res]] + )(implicit ct: ClassTag[Res]) { + def timeout: F[Unit] = + deferred.complete(None).attempt.void + + def complete[M](response: M): F[Either[Throwable, Boolean]] = { + response match { + case expected: Res => + deferred + .complete(Some(expected)) + .attempt + .map(_.isRight.asRight[Throwable]) + case _ => + // Wrong type, as evidenced by `ct` not maching `Res`. + // Returning an error so that this kind of programming error + // can be highlighted as soon as possible. Note though that + // if the request already timed out we can't tell if this + // error would have happened if the response arrived earlier. 
+ val error = new IllegalArgumentException( + s"Invalid response type ${response.getClass.getName}; expected ${ct.runtimeClass.getName}" + ) + deferred.complete(None).attempt.as(error.asLeft[Boolean]) + } + } + } + + def apply[F[_]: Concurrent: Timer, M]( + defaultTimeout: FiniteDuration + ): F[RPCTracker[F, M]] = + Ref[F].of(Map.empty[UUID, Entry[F, _]]).map { + new RPCTracker(_, defaultTimeout) + } +} diff --git a/metronome/core/src/io/iohk/metronome/core/package.scala b/metronome/core/src/io/iohk/metronome/core/package.scala new file mode 100644 index 00000000..7b803e5a --- /dev/null +++ b/metronome/core/src/io/iohk/metronome/core/package.scala @@ -0,0 +1,5 @@ +package io.iohk.metronome + +package object core { + type Validated[U] = Validated.Tagged[U] +} diff --git a/metronome/core/test/src/io/iohk/metronome/core/PipeSpec.scala b/metronome/core/test/src/io/iohk/metronome/core/PipeSpec.scala new file mode 100644 index 00000000..5f24205e --- /dev/null +++ b/metronome/core/test/src/io/iohk/metronome/core/PipeSpec.scala @@ -0,0 +1,27 @@ +package io.iohk.metronome.core + +import org.scalatest.flatspec.AsyncFlatSpec +import monix.eval.Task +import monix.execution.Scheduler.Implicits.global +import org.scalatest.matchers.should.Matchers + +class PipeSpec extends AsyncFlatSpec with Matchers { + + behavior of "Pipe" + + it should "send messages between the sides" in { + val test = for { + pipe <- Pipe[Task, String, Int] + _ <- pipe.left.send("foo") + _ <- pipe.left.send("bar") + _ <- pipe.right.send(1) + rs <- pipe.right.receive.take(2).toListL + ls <- pipe.left.receive.headOptionL + } yield { + rs shouldBe List("foo", "bar") + ls shouldBe Some(1) + } + + test.runToFuture + } +} diff --git a/metronome/core/test/src/io/iohk/metronome/core/fibers/FiberMapSpec.scala b/metronome/core/test/src/io/iohk/metronome/core/fibers/FiberMapSpec.scala new file mode 100644 index 00000000..208dd46a --- /dev/null +++ b/metronome/core/test/src/io/iohk/metronome/core/fibers/FiberMapSpec.scala @@ -0,0 +1,173 @@ +package io.iohk.metronome.core.fibers + +import cats.effect.concurrent.Ref +import monix.eval.Task +import monix.execution.atomic.AtomicInt +import monix.execution.Scheduler.Implicits.global +import org.scalatest.{Inspectors, Inside} +import org.scalatest.compatible.Assertion +import org.scalatest.flatspec.AsyncFlatSpec +import org.scalatest.matchers.should.Matchers +import scala.util.Random +import scala.concurrent.duration._ +import monix.execution.BufferCapacity + +class FiberMapSpec extends AsyncFlatSpec with Matchers with Inside { + + def test(t: Task[Assertion]) = + t.timeout(10.seconds).runToFuture + + def testMap(f: FiberMap[Task, String] => Task[Assertion]) = test { + FiberMap[Task, String]().use(f) + } + + behavior of "FiberMap" + + it should "process tasks in the order they are submitted" in testMap { + fiberMap => + val stateRef = Ref.unsafe[Task, Map[String, Vector[Int]]](Map.empty) + + val keys = List("a", "b", "c") + + val valueMap = keys.map { + _ -> Random.shuffle(Range(0, 10).toVector) + }.toMap + + val tasks = for { + k <- keys + v <- valueMap(k) + } yield (k, v) + + def append(k: String, v: Int): Task[Unit] = + stateRef.update { state => + state.updated(k, state.getOrElse(k, Vector.empty) :+ v) + } + + for { + handles <- Task.traverse(tasks) { case (k, v) => + // This is a version that wouldn't preserve the order: + // append(k, v).start.map(_.join) + fiberMap.submit(k)(append(k, v)) + } + _ <- Task.parTraverse(handles)(identity) + state <- stateRef.get + } yield { + 
Inspectors.forAll(keys) { k => + state(k) shouldBe valueMap(k) + } + } + } + + it should "process tasks concurrently across keys" in testMap { fiberMap => + val running = AtomicInt(0) + val maxRunning = AtomicInt(0) + + val keys = List("a", "b") + val tasks = List.fill(10)(keys).flatten + + for { + handles <- Task.traverse(tasks) { k => + val task = for { + r <- Task(running.incrementAndGet()) + _ <- Task(maxRunning.getAndTransform(m => math.max(m, r))) + _ <- Task.sleep(20.millis) // Increase chance for overlap. + _ <- Task(running.decrement()) + } yield () + + fiberMap.submit(k)(task) + } + _ <- Task.parTraverse(handles)(identity) + } yield { + running.get() shouldBe 0 + maxRunning.get() shouldBe keys.size + } + } + + it should "return a value we can wait on" in testMap { fiberMap => + for { + task <- fiberMap.submit("foo")(Task("spam")) + value <- task + } yield { + value shouldBe "spam" + } + } + + it should "reject new submissions after shutdown" in test { + FiberMap[Task, String]().allocated.flatMap { case (fiberMap, release) => + for { + _ <- fiberMap.submit("foo")(Task("alpha")) + _ <- release + r <- fiberMap.submit("foo")(Task(2)).attempt + } yield { + inside(r) { case Left(ex) => + ex shouldBe a[IllegalStateException] + ex.getMessage should include("shut down") + } + } + } + } + + it should "reject new submissions for keys that hit their capacity limit" in test { + FiberMap[Task, String](BufferCapacity.Bounded(capacity = 1)).use { + fiberMap => + def trySubmit(k: String) = + fiberMap.submit(k)(Task.never).attempt + + for { + _ <- trySubmit("foo") + _ <- trySubmit("foo") + r3 <- trySubmit("foo") + r4 <- trySubmit("bar") + } yield { + inside(r3) { case Left(ex) => + ex shouldBe a[FiberMap.QueueFullException] + } + r4.isRight shouldBe true + } + } + } + + it should "cancel and raise errors in already submitted tasks after shutdown" in test { + FiberMap[Task, String]().allocated.flatMap { case (fiberMap, release) => + for { + r <- fiberMap.submit("foo")(Task.never) + _ <- release + r <- r.attempt + } yield { + inside(r) { case Left(ex) => + ex shouldBe a[DeferredTask.CanceledException] + } + } + } + } + + it should "cancel and raise errors in a canceled task" in testMap { + fiberMap => + for { + _ <- fiberMap.submit("foo")(Task.never) + r <- fiberMap.submit("foo")(Task("easy")) + _ <- fiberMap.cancelQueue("foo") + r <- r.attempt + } yield { + inside(r) { case Left(ex) => + ex shouldBe a[DeferredTask.CanceledException] + } + } + } + + it should "keep processing even if a task fails" in testMap { fiberMap => + for { + t1 <- fiberMap.submit("foo")( + Task.raiseError(new RuntimeException("Boom!")) + ) + t2 <- fiberMap.submit("foo")(Task(2)) + r1 <- t1.attempt + r2 <- t2 + } yield { + inside(r1) { case Left(ex) => + ex.getMessage shouldBe "Boom!" 
+ } + r2 shouldBe 2 + } + } +} diff --git a/metronome/core/test/src/io/iohk/metronome/core/fibers/FiberSetSpec.scala b/metronome/core/test/src/io/iohk/metronome/core/fibers/FiberSetSpec.scala new file mode 100644 index 00000000..ab4205f4 --- /dev/null +++ b/metronome/core/test/src/io/iohk/metronome/core/fibers/FiberSetSpec.scala @@ -0,0 +1,82 @@ +package io.iohk.metronome.core.fibers + +import monix.eval.Task +import monix.execution.Scheduler.Implicits.global +import monix.execution.atomic.AtomicInt +import org.scalatest.compatible.Assertion +import org.scalatest.flatspec.AsyncFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.Inside +import scala.concurrent.duration._ + +class FiberSetSpec extends AsyncFlatSpec with Matchers with Inside { + + def test(t: Task[Assertion]) = + t.timeout(10.seconds).runToFuture + + behavior of "FiberSet" + + it should "reject new submissions after shutdown" in test { + FiberSet[Task].allocated.flatMap { case (fiberSet, release) => + for { + _ <- fiberSet.submit(Task("foo")) + _ <- release + r <- fiberSet.submit(Task("bar")).attempt + } yield { + inside(r) { case Left(ex) => + ex shouldBe a[IllegalStateException] + ex.getMessage should include("shut down") + } + } + } + } + + it should "cancel and raise errors in already submitted tasks after shutdown" in test { + FiberSet[Task].allocated.flatMap { case (fiberSet, release) => + for { + r <- fiberSet.submit(Task.never) + _ <- release + r <- r.attempt + } yield { + inside(r) { case Left(ex) => + ex shouldBe a[DeferredTask.CanceledException] + } + } + } + } + + it should "return a value we can wait on" in test { + FiberSet[Task].use { fiberSet => + for { + task <- fiberSet.submit(Task("spam")) + value <- task + } yield { + value shouldBe "spam" + } + } + } + + it should "process tasks concurrently" in test { + FiberSet[Task].use { fiberSet => + val running = AtomicInt(0) + val maxRunning = AtomicInt(0) + + for { + handles <- Task.traverse(1 to 10) { _ => + val task = for { + r <- Task(running.incrementAndGet()) + _ <- Task(maxRunning.getAndTransform(m => math.max(m, r))) + _ <- Task.sleep(20.millis) // Increase chance for overlap. 
+ _ <- Task(running.decrement()) + } yield () + + fiberSet.submit(task) + } + _ <- Task.parTraverse(handles)(identity) + } yield { + running.get() shouldBe 0 + maxRunning.get() should be > 1 + } + } + } +} diff --git a/metronome/core/test/src/io/iohk/metronome/core/messages/RPCTrackerSpec.scala b/metronome/core/test/src/io/iohk/metronome/core/messages/RPCTrackerSpec.scala new file mode 100644 index 00000000..1ac23f20 --- /dev/null +++ b/metronome/core/test/src/io/iohk/metronome/core/messages/RPCTrackerSpec.scala @@ -0,0 +1,85 @@ +package io.iohk.metronome.core.messages + +import monix.eval.Task +import monix.execution.Scheduler.Implicits.global +import org.scalatest.flatspec.AsyncFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.compatible.Assertion +import org.scalatest.Inside +import scala.concurrent.Future +import scala.concurrent.duration._ + +class RPCTrackerSpec extends AsyncFlatSpec with Matchers with Inside { + + sealed trait TestMessage extends RPCMessage + object TestMessage extends RPCMessageCompanion { + case class FooRequest(requestId: RequestId) extends TestMessage with Request + case class FooResponse(requestId: RequestId, value: Int) + extends TestMessage + with Response + case class BarRequest(requestId: RequestId) extends TestMessage with Request + case class BarResponse(requestId: RequestId, value: String) + extends TestMessage + with Response + + implicit val foo = pair[FooRequest, FooResponse] + implicit val bar = pair[BarRequest, BarResponse] + } + import TestMessage._ + + def test( + f: RPCTracker[Task, TestMessage] => Task[Assertion] + ): Future[Assertion] = + RPCTracker[Task, TestMessage](10.seconds) + .flatMap(f) + .timeout(5.seconds) + .runToFuture + + behavior of "RPCTracker" + + it should "complete responses within the timeout" in test { tracker => + val req = FooRequest(RequestId()) + val res = FooResponse(req.requestId, 1) + for { + join <- tracker.register(req) + ok <- tracker.complete(res) + got <- join + } yield { + ok shouldBe Right(true) + got shouldBe Some(res) + } + } + + it should "complete responses with None after the timeout" in test { + tracker => + val req = FooRequest(RequestId()) + val res = FooResponse(req.requestId, 1) + for { + join <- tracker.register(req, timeout = 50.millis) + _ <- Task.sleep(100.millis) + ok <- tracker.complete(res) + got <- join + } yield { + ok shouldBe Right(false) + got shouldBe empty + } + } + + it should "complete responses with None if the wrong type of response arrives" in test { + tracker => + for { + rid <- RequestId[Task] + req = FooRequest(rid) + res = BarResponse(rid, "one") + join <- tracker.register(req) + ok <- tracker.complete(res) + got <- join + } yield { + inside(ok) { case Left(error) => + error.getMessage should include("Invalid response type") + } + got shouldBe empty + } + } + +} diff --git a/metronome/crypto/src/io/iohk/metronome/crypto/ECKeyPair.scala b/metronome/crypto/src/io/iohk/metronome/crypto/ECKeyPair.scala new file mode 100644 index 00000000..0aaff420 --- /dev/null +++ b/metronome/crypto/src/io/iohk/metronome/crypto/ECKeyPair.scala @@ -0,0 +1,28 @@ +package io.iohk.metronome.crypto + +import org.bouncycastle.crypto.AsymmetricCipherKeyPair + +import java.security.SecureRandom + +/** The pair of EC private and public keys for Secp256k1 elliptic curve */ +case class ECKeyPair(prv: ECPrivateKey, pub: ECPublicKey) { + + /** The bouncycastle's underlying type for efficient use with + * `io.iohk.ethereum.crypto.ECDSASignature` + */ + def underlying: 
AsymmetricCipherKeyPair = prv.underlying +} + +object ECKeyPair { + + def apply(keyPair: AsymmetricCipherKeyPair): ECKeyPair = { + val (prv, pub) = io.iohk.ethereum.crypto.keyPairToByteArrays(keyPair) + ECKeyPair(ECPrivateKey(prv), ECPublicKey(pub)) + } + + /** Generates a new keypair on the Secp256k1 elliptic curve */ + def generate(secureRandom: SecureRandom): ECKeyPair = { + val kp = io.iohk.ethereum.crypto.generateKeyPair(secureRandom) + ECKeyPair(kp) + } +} diff --git a/metronome/crypto/src/io/iohk/metronome/crypto/ECPrivateKey.scala b/metronome/crypto/src/io/iohk/metronome/crypto/ECPrivateKey.scala new file mode 100644 index 00000000..3e0102f7 --- /dev/null +++ b/metronome/crypto/src/io/iohk/metronome/crypto/ECPrivateKey.scala @@ -0,0 +1,27 @@ +package io.iohk.metronome.crypto + +import org.bouncycastle.crypto.AsymmetricCipherKeyPair +import scodec.bits.ByteVector +import io.iohk.ethereum.crypto.keyPairFromPrvKey + +/** Wraps the bytes representing an EC private key */ +case class ECPrivateKey(bytes: ByteVector) { + require( + bytes.length == ECPrivateKey.Length, + s"Key must be ${ECPrivateKey.Length} bytes long" + ) + + /** Converts the byte representation to bouncycastle's `AsymmetricCipherKeyPair` for efficient use with + * `io.iohk.ethereum.crypto.ECDSASignature` + */ + val underlying: AsymmetricCipherKeyPair = keyPairFromPrvKey( + bytes.toArray + ) +} + +object ECPrivateKey { + val Length = 32 + + def apply(bytes: Array[Byte]): ECPrivateKey = + ECPrivateKey(ByteVector(bytes)) +} diff --git a/metronome/crypto/src/io/iohk/metronome/crypto/ECPublicKey.scala b/metronome/crypto/src/io/iohk/metronome/crypto/ECPublicKey.scala new file mode 100644 index 00000000..5036fb1d --- /dev/null +++ b/metronome/crypto/src/io/iohk/metronome/crypto/ECPublicKey.scala @@ -0,0 +1,22 @@ +package io.iohk.metronome.crypto + +import scodec.Codec +import scodec.bits.ByteVector +import scodec.codecs.bytes + +/** Wraps the bytes representing an EC public key in uncompressed format and without the compression indicator */ +case class ECPublicKey(bytes: ByteVector) { + require( + bytes.length == ECPublicKey.Length, + s"Key must be ${ECPublicKey.Length} bytes long" + ) +} + +object ECPublicKey { + val Length = 64 + + def apply(bytes: Array[Byte]): ECPublicKey = + ECPublicKey(ByteVector(bytes)) + + implicit val codec: Codec[ECPublicKey] = bytes.as[ECPublicKey] +} diff --git a/metronome/crypto/src/io/iohk/metronome/crypto/GroupSignature.scala b/metronome/crypto/src/io/iohk/metronome/crypto/GroupSignature.scala new file mode 100644 index 00000000..a453c135 --- /dev/null +++ b/metronome/crypto/src/io/iohk/metronome/crypto/GroupSignature.scala @@ -0,0 +1,7 @@ +package io.iohk.metronome.crypto + +/** Group signature of members with identity `K` over some content `H`, + * represented by type `G`, e.g. `G` could be a `List[Secp256k1Signature]` + * or a single combined threshold signature of some sort. + */ +case class GroupSignature[K, H, G](sig: G) diff --git a/metronome/crypto/src/io/iohk/metronome/crypto/PartialSignature.scala b/metronome/crypto/src/io/iohk/metronome/crypto/PartialSignature.scala new file mode 100644 index 00000000..de6eeeef --- /dev/null +++ b/metronome/crypto/src/io/iohk/metronome/crypto/PartialSignature.scala @@ -0,0 +1,7 @@ +package io.iohk.metronome.crypto + +/** An individual signature of a member with identity `K` over some content `H`, + * represented by type `P`, e.g. `P` could be a single `Secp256k1Signature` + * or a partial threshold signature of some sort. 
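A small sketch of how these wrappers fit together, using only the constructors shown above:

```scala
import java.security.SecureRandom
import io.iohk.metronome.crypto.{ECKeyPair, ECPublicKey}

val secureRandom = new SecureRandom()

// Generate a fresh Secp256k1 keypair; the public key is the 64-byte
// uncompressed form without the compression indicator.
val keyPair: ECKeyPair = ECKeyPair.generate(secureRandom)
assert(keyPair.pub.bytes.length == ECPublicKey.Length)
```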
+ */ +case class PartialSignature[K, H, P](sig: P) diff --git a/metronome/crypto/src/io/iohk/metronome/crypto/hash/Hash.scala b/metronome/crypto/src/io/iohk/metronome/crypto/hash/Hash.scala new file mode 100644 index 00000000..b4d0df14 --- /dev/null +++ b/metronome/crypto/src/io/iohk/metronome/crypto/hash/Hash.scala @@ -0,0 +1,6 @@ +package io.iohk.metronome.crypto.hash + +import io.iohk.metronome.core.Tagger +import scodec.bits.ByteVector + +object Hash extends Tagger[ByteVector] diff --git a/metronome/crypto/src/io/iohk/metronome/crypto/hash/Keccak256.scala b/metronome/crypto/src/io/iohk/metronome/crypto/hash/Keccak256.scala new file mode 100644 index 00000000..b6e96a8b --- /dev/null +++ b/metronome/crypto/src/io/iohk/metronome/crypto/hash/Keccak256.scala @@ -0,0 +1,20 @@ +package io.iohk.metronome.crypto.hash + +import org.bouncycastle.crypto.digests.KeccakDigest +import scodec.bits.{BitVector, ByteVector} + +object Keccak256 { + def apply(data: Array[Byte]): Hash = { + val output = new Array[Byte](32) + val digest = new KeccakDigest(256) + digest.update(data, 0, data.length) + digest.doFinal(output, 0) + Hash(ByteVector(output)) + } + + def apply(data: ByteVector): Hash = + apply(data.toArray) + + def apply(data: BitVector): Hash = + apply(data.toByteArray) +} diff --git a/metronome/crypto/src/io/iohk/metronome/crypto/hash/package.scala b/metronome/crypto/src/io/iohk/metronome/crypto/hash/package.scala new file mode 100644 index 00000000..e9df1174 --- /dev/null +++ b/metronome/crypto/src/io/iohk/metronome/crypto/hash/package.scala @@ -0,0 +1,5 @@ +package io.iohk.metronome.crypto + +package object hash { + type Hash = Hash.Tagged +} diff --git a/metronome/crypto/test/src/io/iohk/metronome/crypto/hash/Keccak256Spec.scala b/metronome/crypto/test/src/io/iohk/metronome/crypto/hash/Keccak256Spec.scala new file mode 100644 index 00000000..ce87ac86 --- /dev/null +++ b/metronome/crypto/test/src/io/iohk/metronome/crypto/hash/Keccak256Spec.scala @@ -0,0 +1,21 @@ +package io.iohk.metronome.crypto.hash + +import scodec.bits._ +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class Keccak256Spec extends AnyFlatSpec with Matchers { + behavior of "Keccak256" + + it should "hash empty data" in { + Keccak256( + "".getBytes + ) shouldBe hex"c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470" + } + + it should "hash non-empty data" in { + Keccak256( + "abc".getBytes + ) shouldBe hex"4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45" + } +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/Federation.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/Federation.scala new file mode 100644 index 00000000..bec35752 --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/Federation.scala @@ -0,0 +1,79 @@ +package io.iohk.metronome.hotstuff.consensus + +/** Collection of keys of the federation members. + * + * There are two inequalities that decide the quorum size `q`: + * + * 1.) Safety inequality: + * There should not be two conflicting quorums. + * If two quorums conflict, their intersection has a size of `2q-n`. + * The intersection represents equivocation and will have a size of + * at most `f` (since honest nodes don't equivocate). + * Thus for safety we need `2q-n > f` => `q > (n+f)/2` + * + * 2.) Liveness inequality: + * Quorum size should be small enough so that adversaries cannot deadlock + * the system by not voting. 
If the quorum size is greater than `n-f`, + * adversaries may decide to not vote and hence we will not have any quorum certificate. + * Thus, we need `q <= n-f` + * + * So any `q` between `(n+f)/2+1` and `n-f` should work. + * Smaller `q` is preferred as it would improve speed. + * We can set it to `(n+f)/2+1` or fix it to `2/3n+1`. + * + * Extra: The above two inequalities `(n+f)/2 < q <= n-f`, lead to the constraint: `f < n/3`, or `n >= 3*f+1`. + */ +abstract case class Federation[PKey] private ( + publicKeys: IndexedSeq[PKey], + // Maximum number of Byzantine nodes. + maxFaulty: Int +)(implicit ls: LeaderSelection) { + private val publicKeySet = publicKeys.toSet + + /** Size of the federation. */ + val size: Int = publicKeys.size + + /** Number of signatures required for a Quorum Certificate. */ + val quorumSize: Int = (size + maxFaulty) / 2 + 1 + + def contains(publicKey: PKey): Boolean = + publicKeySet.contains(publicKey) + + def leaderOf(viewNumber: ViewNumber): PKey = + publicKeys(implicitly[LeaderSelection].leaderOf(viewNumber, size)) +} + +object Federation { + + /** Create a federation with the highest possible fault tolerance. */ + def apply[PKey]( + publicKeys: IndexedSeq[PKey] + )(implicit ls: LeaderSelection): Either[String, Federation[PKey]] = + apply(publicKeys, maxByzantine(publicKeys.size)) + + /** Create a federation with the fault tolerance possibly reduced from the theoretical + * maximum, which can allow smaller quorum sizes, and improved speed. + * + * Returns an error if the configured value is higher than the theoretically tolerable maximum. + */ + def apply[PKey]( + publicKeys: IndexedSeq[PKey], + maxFaulty: Int + )(implicit ls: LeaderSelection): Either[String, Federation[PKey]] = { + val f = maxByzantine(publicKeys.size) + if (publicKeys.isEmpty) { + Left("The federation cannot be empty!") + } else if (publicKeys.distinct.size < publicKeys.size) { + Left("The keys in the federation must be unique!") + } else if (maxFaulty > f) { + Left( + s"The maximum tolerable number of Byzantine members is $f, less than the specified $maxFaulty." + ) + } else { + Right(new Federation(publicKeys, maxFaulty) {}) + } + } + + /** Maximum number of Byzantine nodes in a federation of size `n` */ + def maxByzantine(n: Int): Int = (n - 1) / 3 +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/LeaderSelection.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/LeaderSelection.scala new file mode 100644 index 00000000..a4940711 --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/LeaderSelection.scala @@ -0,0 +1,53 @@ +package io.iohk.metronome.hotstuff.consensus + +import io.iohk.metronome.crypto.hash.Keccak256 +import scodec.bits.ByteVector + +/** Strategy to pick the leader for a given view number from + * federation of with a fixed size. + */ +trait LeaderSelection { + + /** Return the index of the federation member who should lead the view. */ + def leaderOf(viewNumber: ViewNumber, size: Int): Int +} + +object LeaderSelection { + + /** Simple strategy cycling through leaders in a static order. */ + object RoundRobin extends LeaderSelection { + override def leaderOf(viewNumber: ViewNumber, size: Int): Int = + (viewNumber % size).toInt + } + + /** Leader assignment based on view-number has not been discussed in the Hotstuff + * paper and in general, it does not affect the safety and liveness. + * However, it does affect worst-case latency. 
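To make the arithmetic above concrete, a small sketch with four members under the round-robin strategy; string keys stand in for real public keys:

```scala
import io.iohk.metronome.hotstuff.consensus.{Federation, LeaderSelection, ViewNumber}

implicit val leaderSelection: LeaderSelection = LeaderSelection.RoundRobin

val publicKeys = IndexedSeq("alice", "bob", "carol", "dave")

// With n = 4 the tolerated Byzantine count is f = (4 - 1) / 3 = 1,
// so the quorum size is (4 + 1) / 2 + 1 = 3.
Federation(publicKeys) match {
  case Right(federation) =>
    assert(federation.maxFaulty == 1)
    assert(federation.quorumSize == 3)
    assert(federation.leaderOf(ViewNumber(5)) == "bob") // 5 % 4 = 1
  case Left(error) =>
    println(s"Invalid federation: $error")
}
```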
+ * + * Consider a static adversary under a round-robin leader change scheme. + * All the f nodes can set their public keys so that they are consecutive. + * In such a scenario those f consecutive leaders can create timeouts leading + * to an O(f) confirmation latency. (Recall that in a normal case, the latency is O(1)). + * + * A minor improvement to this is to assign leaders based on + * "publicKeys((H256(viewNumber).toInt % size).toInt)". + * + * This leader order randomization via a hash function will ensure that even + * if adversarial public keys are consecutive in PublicKey set, they are not + * necessarily consecutive in leader order. + * + * Note that the above policy will not ensure that adversarial leaders are never consecutive, + * but the probability of such occurrence will be lower under a static adversary. + */ + object Hashing extends LeaderSelection { + override def leaderOf(viewNumber: ViewNumber, size: Int): Int = { + val bytes = ByteVector.fromLong(viewNumber) // big-endian + val hash = Keccak256(bytes) + // If we prepend 0.toByte then it would treat it as unsigned, at the cost of an array copy. + // Instead of doing that I'll just make sure we deal with negative modulo. + val num = BigInt(hash.toArray) + val mod = (num % size).toInt + if (mod < 0) mod + size else mod + } + } +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/ViewNumber.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/ViewNumber.scala new file mode 100644 index 00000000..941f7394 --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/ViewNumber.scala @@ -0,0 +1,17 @@ +package io.iohk.metronome.hotstuff.consensus + +import io.iohk.metronome.core.Tagger +import cats.kernel.Order + +object ViewNumber extends Tagger[Long] { + implicit class Ops(val vn: ViewNumber) extends AnyVal { + def next: ViewNumber = ViewNumber(vn + 1) + def prev: ViewNumber = ViewNumber(vn - 1) + } + + implicit val ord: Ordering[ViewNumber] = + Ordering.by(identity[Long]) + + implicit val order: Order[ViewNumber] = + Order.fromOrdering(ord) +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Agreement.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Agreement.scala new file mode 100644 index 00000000..4361bf4e --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Agreement.scala @@ -0,0 +1,27 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +/** Capture all the generic types in the BFT agreement, + * so we don't have to commit to any particular set of content. + */ +trait Agreement { + + /** The container type that the agreement is about. */ + type Block + + /** The type we use for hashing blocks, + * so they don't have to be sent in entirety in votes. + */ + type Hash + + /** The concrete type that represents a partial signature. */ + type PSig + + /** The concrete type that represents a group signature. */ + type GSig + + /** The public key identity of federation members. */ + type PKey + + /** The secret key used for signing partial messages. 
*/ + type SKey +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Block.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Block.scala new file mode 100644 index 00000000..70780d05 --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Block.scala @@ -0,0 +1,31 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +/** Type class to project the properties we need a HotStuff block to have + * from the generic `Block` type in the `Agreement`. + * + * This allows the block to include use-case specific details HotStuff doesn't + * care about, for example to build up a ledger state that can be synchronised + * directly, rather than just carry out a sequence of commands on all replicas. + * This would require the blocks to contain ledger state hashes, which other + * use cases may have no use for. + */ +trait Block[A <: Agreement] { + def blockHash(b: A#Block): A#Hash + def parentBlockHash(b: A#Block): A#Hash + def height(b: A#Block): Long + + /** Perform simple content validation, e.g. + * whether the block hash matches the header + * and the header content matches the body. + */ + def isValid(b: A#Block): Boolean + + def isParentOf(parent: A#Block, child: A#Block): Boolean = { + parentBlockHash(child) == blockHash(parent) && + height(child) == height(parent) + 1 + } +} + +object Block { + def apply[A <: Agreement: Block]: Block[A] = implicitly[Block[A]] +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Effect.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Effect.scala new file mode 100644 index 00000000..a2737e8c --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Effect.scala @@ -0,0 +1,73 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +import scala.concurrent.duration.FiniteDuration + +import io.iohk.metronome.hotstuff.consensus.ViewNumber + +/** Represent all possible effects that a protocol transition can + * ask the host system to carry out, e.g. send messages to replicas. + */ +sealed trait Effect[+A <: Agreement] + +object Effect { + + /** Schedule a callback after a timeout to initiate the next view + * if the current rounds ends without an agreement. + */ + case class ScheduleNextView( + viewNumber: ViewNumber, + timeout: FiniteDuration + ) extends Effect[Nothing] + + /** Send a message to a federation member. + * + * The recipient can be the current member itself (i.e. the leader + * sending itself a message to trigger its own vote). It is best + * if the host system carries out these effects before it talks + * to the external world, to avoid any possible phase mismatches. + * + * The `ProtocolState` could do it on its own but this way it's + * slightly closer to the pseudo code. + */ + case class SendMessage[A <: Agreement]( + recipient: A#PKey, + message: Message[A] + ) extends Effect[A] + + /** The leader of the round wants to propose a new block + * on top of the last prepared one. The host environment + * should consult the mempool and create one, passing the + * result as an event. + * + * The block must be built as a child of `highQC.blockHash`. + */ + case class CreateBlock[A <: Agreement]( + viewNumber: ViewNumber, + highQC: QuorumCertificate[A] + ) extends Effect[A] + + /** Once the Prepare Q.C. has been established for a block, + * we know that it's not spam, it's safe to be persisted. 
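These two abstractions are what a concrete use case has to instantiate; below is a toy instance, with all names (`TestBlock`, `TestAgreement`) purely illustrative:

```scala
import io.iohk.metronome.hotstuff.consensus.basic.{Agreement, Block}

case class TestBlock(hash: String, parentHash: String, height: Long)

object TestAgreement extends Agreement {
  type Block = TestBlock
  type Hash  = String
  type PSig  = Long
  type GSig  = Seq[Long]
  type PKey  = Int
  type SKey  = Int
}

implicit val testBlock: Block[TestAgreement.type] =
  new Block[TestAgreement.type] {
    def blockHash(b: TestBlock): String       = b.hash
    def parentBlockHash(b: TestBlock): String = b.parentHash
    def height(b: TestBlock): Long            = b.height
    def isValid(b: TestBlock): Boolean        = true
  }

val genesis = TestBlock("g", parentHash = "", height = 0)
val child   = TestBlock("c", parentHash = "g", height = 1)
assert(Block[TestAgreement.type].isParentOf(genesis, child))
```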
+ * + * This prevents a rouge leader from sending us many `Prepare` + * messages in the same view with the intention of eating up + * space using the included block. + * + * It's also a way for us to delay saving a block we created + * as a leader to the time when it's been voted on. Since it's + * part of the `Prepare` message, replicas shouldn't be asking + * for it anyway, so it's not a problem if it's not yet persisted. + */ + case class SaveBlock[A <: Agreement]( + preparedBlock: A#Block + ) extends Effect[A] + + /** Execute blocks after a decision, from the last executed hash + * up to the block included in the Quorum Certificate. + */ + case class ExecuteBlocks[A <: Agreement]( + lastExecutedBlockHash: A#Hash, + quorumCertificate: QuorumCertificate[A] + ) extends Effect[A] + +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Event.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Event.scala new file mode 100644 index 00000000..5a03c359 --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Event.scala @@ -0,0 +1,26 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +import io.iohk.metronome.hotstuff.consensus.ViewNumber + +/** Input events for the protocol model. */ +sealed trait Event[+A <: Agreement] + +object Event { + + /** A scheduled timeout for the round, initiating the next view. */ + case class NextView(viewNumber: ViewNumber) extends Event[Nothing] + + /** A message received from a federation member. */ + case class MessageReceived[A <: Agreement]( + sender: A#PKey, + message: Message[A] + ) extends Event[A] + + /** The block the leader asked to be created is ready. */ + case class BlockCreated[A <: Agreement]( + viewNumber: ViewNumber, + block: A#Block, + // The certificate which the block extended. + highQC: QuorumCertificate[A] + ) extends Event[A] +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Message.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Message.scala new file mode 100644 index 00000000..9980905e --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Message.scala @@ -0,0 +1,69 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +import io.iohk.metronome.crypto.PartialSignature +import io.iohk.metronome.hotstuff.consensus.ViewNumber + +/** Basic HotStuff protocol messages. */ +sealed trait Message[A <: Agreement] { + + /** Messages are only accepted if they match the node's current view number. */ + def viewNumber: ViewNumber +} + +/** Message from the leader to the replica. */ +sealed trait LeaderMessage[A <: Agreement] extends Message[A] + +/** Message from the replica to the leader. */ +sealed trait ReplicaMessage[A <: Agreement] extends Message[A] + +object Message { + + /** The leader proposes a new block in the `Prepare` phase, + * using the High Q.C. gathered from `NewView` messages. + */ + case class Prepare[A <: Agreement]( + viewNumber: ViewNumber, + block: A#Block, + highQC: QuorumCertificate[A] + ) extends LeaderMessage[A] + + /** Having received one of the leader messages, the replica + * casts its vote with its partical signature. + * + * The vote carries either the hash of the block, which + * was either received full in the `Prepare` message, + * or as part of a `QuorumCertificate`. 
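The protocol never performs these effects itself; the host system interprets them. Below is a deliberately trivial interpreter sketch that only describes what would have to happen (real wiring to networking, storage and scheduling is out of scope here):

```scala
import io.iohk.metronome.hotstuff.consensus.basic.{Agreement, Effect}

def describe[A <: Agreement](effect: Effect[A]): String =
  effect match {
    case Effect.ScheduleNextView(viewNumber, timeout) =>
      s"schedule a NextView($viewNumber) timeout event after $timeout"
    case Effect.SendMessage(recipient, message) =>
      s"send $message to $recipient"
    case Effect.CreateBlock(viewNumber, highQC) =>
      s"build a block on top of ${highQC.blockHash} for view $viewNumber"
    case Effect.SaveBlock(block) =>
      s"persist the prepared block $block"
    case Effect.ExecuteBlocks(lastExecuted, commitQC) =>
      s"execute blocks from $lastExecuted up to ${commitQC.blockHash}"
  }
```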
+ */ + case class Vote[A <: Agreement]( + viewNumber: ViewNumber, + phase: VotingPhase, + blockHash: A#Hash, + signature: PartialSignature[ + A#PKey, + (VotingPhase, ViewNumber, A#Hash), + A#PSig + ] + ) extends ReplicaMessage[A] + + /** Having collected enough votes from replicas, + * the leader combines the votes into a Q.C. and + * broadcasts it to replicas: + * - Prepare votes combine into a Prepare Q.C., expected in the PreCommit phase. + * - PreCommit votes combine into a PreCommit Q.C., expected in the Commit phase. + * - Commit votes combine into a Commit Q.C, expected in the Decide phase. + * + * The certificate contains the hash of the block to vote on. + */ + case class Quorum[A <: Agreement]( + viewNumber: ViewNumber, + quorumCertificate: QuorumCertificate[A] + ) extends LeaderMessage[A] + + /** At the end of the round, replicas send the `NewView` message + * to the next leader with the last Prepare Q.C. + */ + case class NewView[A <: Agreement]( + viewNumber: ViewNumber, + prepareQC: QuorumCertificate[A] + ) extends ReplicaMessage[A] +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Phase.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Phase.scala new file mode 100644 index 00000000..f7a83534 --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Phase.scala @@ -0,0 +1,49 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +/** All phases of the basic HotStuff protocol. */ +sealed trait Phase { + import Phase._ + def next: Phase = + this match { + case Prepare => PreCommit + case PreCommit => Commit + case Commit => Decide + case Decide => Prepare + } + + def prev: Phase = + this match { + case Prepare => Decide + case PreCommit => Prepare + case Commit => PreCommit + case Decide => Commit + } + + /** Check that *within the same view* phase this phase precedes the other. */ + def isBefore(other: Phase): Boolean = + (this, other) match { + case (Prepare, PreCommit | Commit | Decide) => true + case (PreCommit, Commit | Decide) => true + case (Commit, Decide) => true + case _ => false + } + + /** Check that *within the same view* this phase follows the other. */ + def isAfter(other: Phase): Boolean = + (this, other) match { + case (PreCommit, Prepare) => true + case (Commit, Prepare | PreCommit) => true + case (Decide, Prepare | PreCommit | Commit) => true + case _ => false + } +} + +/** Subset of phases over which there can be vote and a Quorum Certificate. */ +sealed trait VotingPhase extends Phase + +object Phase { + case object Prepare extends VotingPhase + case object PreCommit extends VotingPhase + case object Commit extends VotingPhase + case object Decide extends Phase +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/ProtocolError.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/ProtocolError.scala new file mode 100644 index 00000000..08b437ed --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/ProtocolError.scala @@ -0,0 +1,77 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +import io.iohk.metronome.hotstuff.consensus.ViewNumber + +sealed trait ProtocolError[A <: Agreement] + +object ProtocolError { + + /** A leader message was received from a replica that isn't the leader of the view. 
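A quick illustration of the phase cycle and the intra-view ordering helpers defined above:

```scala
import io.iohk.metronome.hotstuff.consensus.basic.Phase

assert(Phase.Prepare.next == Phase.PreCommit)
assert(Phase.Decide.next == Phase.Prepare)     // wraps around into the next view
assert(Phase.PreCommit.isAfter(Phase.Prepare))
assert(Phase.Commit.isBefore(Phase.Decide))
assert(!Phase.Decide.isBefore(Phase.Prepare))  // ordering only holds within one view
```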
*/
+  case class NotFromLeader[A <: Agreement](
+      event: Event.MessageReceived[A],
+      expected: A#PKey
+  ) extends ProtocolError[A]
+
+  /** A replica message was received in a view that this replica is not leading. */
+  case class NotToLeader[A <: Agreement](
+      event: Event.MessageReceived[A],
+      expected: A#PKey
+  ) extends ProtocolError[A]
+
+  /** A message coming from outside the federation members. */
+  case class NotFromFederation[A <: Agreement](
+      event: Event.MessageReceived[A]
+  ) extends ProtocolError[A]
+
+  /** The vote signature doesn't match the content. */
+  case class InvalidVote[A <: Agreement](
+      sender: A#PKey,
+      message: Message.Vote[A]
+  ) extends ProtocolError[A]
+
+  /** The Q.C. signature doesn't match the content. */
+  case class InvalidQuorumCertificate[A <: Agreement](
+      sender: A#PKey,
+      quorumCertificate: QuorumCertificate[A]
+  ) extends ProtocolError[A]
+
+  /** The block in the prepare message doesn't extend the previous Q.C. */
+  case class UnsafeExtension[A <: Agreement](
+      sender: A#PKey,
+      message: Message.Prepare[A]
+  ) extends ProtocolError[A]
+
+  /** A vote or Q.C. referring to a different block hash than the one currently being voted on. */
+  case class UnexpectedBlockHash[A <: Agreement](
+      event: Event.MessageReceived[A],
+      expected: A#Hash
+  ) extends ProtocolError[A]
+
+  /** A message that we received slightly earlier than we expected.
+    *
+    * One reason for this could be that the peer is slightly ahead of us,
+    * e.g. already finished the `Decide` phase and sent out the `NewView`
+    * to us, the next leader, in which case the view number would not
+    * match up. Or maybe a quorum has already formed for the next round
+    * and we receive a `Prepare`, while we're still in `Decide`.
+    *
+    * The host system passing the events and processing the effects
+    * is expected to inspect `TooEarly` messages and decide what to do:
+    * - if the message is for the next round or next phase, then just re-deliver it after the view transition
+    * - if the message is far in the future, perhaps it's best to re-sync the status with everyone
+    */
+  case class TooEarly[A <: Agreement](
+      event: Event.MessageReceived[A],
+      expectedInViewNumber: ViewNumber,
+      expectedInPhase: Phase
+  ) extends ProtocolError[A]
+
+  /** A message we didn't expect to receive in the given state.
+    *
+    * The host system can maintain some metrics so we can see if we're completely out of
+    * alignment with all the other peers.
+    */
+  case class Unexpected[A <: Agreement](
+      event: Event.MessageReceived[A]
+  ) extends ProtocolError[A]
+} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/ProtocolState.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/ProtocolState.scala new file mode 100644 index 00000000..cfe75eb7 --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/ProtocolState.scala @@ -0,0 +1,482 @@ +package io.iohk.metronome.hotstuff.consensus.basic
+
+import io.iohk.metronome.core.Validated
+import io.iohk.metronome.hotstuff.consensus.{ViewNumber, Federation}
+import scala.concurrent.duration.FiniteDuration
+
+/** Basic HotStuff protocol state machine.
+  *
+  * See https://arxiv.org/pdf/1803.05069.pdf
+  *
+  * ```
+  *
+  * PHASE LEADER REPLICA
+  * | |
+  * | <--- NewView(prepareQC) ---- |
+  * # Prepare --------------- | |
+  * select highQC | |
+  * create block | |
+  * | ------ Prepare(block) -----> |
+  * | | check safety
+  * | <----- Vote(Prepare) ------- |
+  * # PreCommit ------------- | |
+  * | ------ Prepare Q.C.
-------> | + * | | save as prepareQC + * | <----- Vote(PreCommit) ----- | + * # Commit ---------------- | | + * | ------ PreCommit Q.C. -----> | + * | | save as lockedQC + * | <----- Vote(Commit) -------- | + * # Decide ---------------- | | + * | ------ Commit Q.C. --------> | + * | | execute block + * | <--- NewView(prepareQC) ---- | + * | | + * + * ``` + */ +case class ProtocolState[A <: Agreement: Block: Signing]( + viewNumber: ViewNumber, + phase: Phase, + publicKey: A#PKey, + signingKey: A#SKey, + federation: Federation[A#PKey], + // Highest QC for which a replica voted Pre-Commit, because it received a Prepare Q.C. from the leader. + prepareQC: QuorumCertificate[A], + // Locked QC, for which a replica voted Commit, because it received a Pre-Commit Q.C. from leader. + lockedQC: QuorumCertificate[A], + // Commit QC, which a replica received in the Decide phase, and then executed the block in it. + commitQC: QuorumCertificate[A], + // The block the federation is currently voting on. + preparedBlock: A#Block, + // Timeout for the view, so that it can be adjusted next time if necessary. + timeout: FiniteDuration, + // Votes gathered by the leader in this phase. They are guarenteed to be over the same content. + votes: Set[Message.Vote[A]], + // NewView messages gathered by the leader during the Prepare phase. Map so every sender can only give one. + newViews: Map[A#PKey, Message.NewView[A]] +) { + import Message._ + import Effect._ + import Event._ + import ProtocolState._ + import ProtocolError._ + + val leader = federation.leaderOf(viewNumber) + val isLeader = leader == publicKey + + /** The leader has to collect `n-f` signatures into a Q.C. + * + * This value can be lower if we have higher trust in the federation. + */ + def quorumSize = federation.quorumSize + + /** Hash of the block that was last decided upon. */ + def lastExecutedBlockHash: A#Hash = commitQC.blockHash + + /** Hash of the block currently being voted on. */ + def preparedBlockHash: A#Hash = Block[A].blockHash(preparedBlock) + + /** No state transition. */ + private def stay: Transition[A] = + this -> Nil + + private def moveTo(phase: Phase): ProtocolState[A] = + copy( + viewNumber = if (phase == Phase.Prepare) viewNumber.next else viewNumber, + phase = phase, + votes = Set.empty, + newViews = Map.empty + ) + + /** The round has timed out; send `prepareQC` to the leader + * of the next view and move to that view now. + */ + def handleNextView(e: NextView): Transition[A] = + if (e.viewNumber == viewNumber) { + val next = moveTo(Phase.Prepare) + val effects = Seq( + SendMessage(next.leader, NewView(viewNumber, prepareQC)), + ScheduleNextView(next.viewNumber, next.timeout) + ) + next -> effects + } else stay + + /** A block we asked the host system to create using `Effect.CreateBlock` is + * ready to be broadcasted, if we're still in the same view. + */ + def handleBlockCreated(e: BlockCreated[A]): Transition[A] = + if (e.viewNumber == viewNumber && isLeader && phase == Phase.Prepare) { + // TODO: If the block is empty, we could just repeat the agreement on + // the previous Q.C. to simulate being idle, without timing out. + val effects = broadcast { + Prepare(viewNumber, e.block, e.highQC) + } + this -> effects + } else stay + + /** Filter out messages that are completely invalid, + * independent of the current phase and view number, + * i.e. stateless validation. 
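+    * (For example: the sender must be a federation member and all signatures
+    * must verify; none of these checks depend on this replica's current phase.)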
+ * + * This check can be performed before for example the + * block contents in the `Prepare` message are validated, + * so that we don't waste time with spam. + */ + def validateMessage( + e: MessageReceived[A] + ): Either[ProtocolError[A], Validated[MessageReceived[A]]] = { + val currLeader = federation.leaderOf(e.message.viewNumber) + val nextLeader = federation.leaderOf(e.message.viewNumber.next) + + e.message match { + case _ if !federation.contains(e.sender) => + Left(NotFromFederation(e)) + + case _: LeaderMessage[_] if e.sender != currLeader => + Left(NotFromLeader(e, currLeader)) + + case m: ReplicaMessage[_] + if !m.isInstanceOf[NewView[_]] && publicKey != currLeader => + Left(NotToLeader(e, currLeader)) + + case _: NewView[_] if publicKey != nextLeader => + Left(NotToLeader(e, nextLeader)) + + case m: Vote[_] if !Signing[A].validate(e.sender, m) => + Left(InvalidVote(e.sender, m)) + + case m: Quorum[_] + if !Signing[A].validate(federation, m.quorumCertificate) => + Left(InvalidQuorumCertificate(e.sender, m.quorumCertificate)) + + case m: NewView[_] if m.prepareQC.phase != Phase.Prepare => + Left(InvalidQuorumCertificate(e.sender, m.prepareQC)) + + case m: NewView[_] if !Signing[A].validate(federation, m.prepareQC) => + Left(InvalidQuorumCertificate(e.sender, m.prepareQC)) + + case m: Prepare[_] if !Signing[A].validate(federation, m.highQC) => + Left(InvalidQuorumCertificate(e.sender, m.highQC)) + + case _ => + Right(Validated[MessageReceived[A]](e)) + } + } + + /** Handle an incoming message that has already gone through partial validation: + * + * The sender is verified by the network layer and retrieved from the + * lower level protocol message; we know the signatures are correct; + * and the contents of any proposed block have been validated as well, + * so they are safe to be voted on. + * + * Return the updated state and any effects to be carried out in response, + * or an error, so that mismatches can be traced. Discrepancies can arise + * from the state being different or have changed since the message originally + * received. + * + * The structure of the method tries to match the pseudo code of `Algorithm 2` + * in the HotStuff paper. + */ + def handleMessage( + e: Validated[MessageReceived[A]] + ): TransitionAttempt[A] = + phase match { + // Leader: Collect NewViews, create block, boradcast Prepare + // Replica: Wait for Prepare, check safe extension, vote Prepare, move to PreCommit. + case Phase.Prepare => + matchingMsg(e) { + case m: NewView[_] if m.viewNumber == viewNumber.prev && isLeader => + Right(addNewViewAndMaybeCreateBlock(e.sender, m)) + + case m: Prepare[_] if matchingLeader(e) => + if (isSafe(m)) { + val blockHash = Block[A].blockHash(m.block) + val effects = Seq( + sendVote(Phase.Prepare, blockHash) + ) + val next = moveTo(Phase.PreCommit).copy( + preparedBlock = m.block + ) + Right(next -> effects) + } else { + Left(UnsafeExtension(e.sender, m)) + } + } + + // Leader: Collect Prepare votes, broadcast Prepare Q.C. + // Replica: Wait for Prepare Q.C, save prepareQC, vote PreCommit, move to Commit. + case Phase.PreCommit => + matchingMsg(e) { + handleVotes(e, Phase.Prepare) orElse + handleQuorum(e, Phase.Prepare) { m => + val effects = Seq( + sendVote(Phase.PreCommit, m.quorumCertificate.blockHash), + SaveBlock(preparedBlock) + ) + val next = moveTo(Phase.Commit).copy( + prepareQC = m.quorumCertificate + ) + next -> effects + } + } + + // Leader: Collect PreCommit votes, broadcast PreCommit Q.C. 
+ // Replica: Wait for PreCommit Q.C., save lockedQC, vote Commit, move to Decide. + case Phase.Commit => + matchingMsg(e) { + handleVotes(e, Phase.PreCommit) orElse + handleQuorum(e, Phase.PreCommit) { m => + val effects = Seq( + sendVote(Phase.Commit, m.quorumCertificate.blockHash) + ) + val next = moveTo(Phase.Decide).copy( + lockedQC = m.quorumCertificate + ) + next -> effects + } + } + + // Leader: Collect Commit votes, broadcast Commit Q.C. + // Replica: Wait for Commit Q.C., execute block, send NewView, move to Prepare. + case Phase.Decide => + matchingMsg(e) { + handleVotes(e, Phase.Commit) orElse + handleQuorum(e, Phase.Commit) { m => + handleNextView(NextView(viewNumber)) match { + case (next, effects) => + val withExec = ExecuteBlocks( + lastExecutedBlockHash, + m.quorumCertificate + ) +: effects + + val withLast = next.copy(commitQC = m.quorumCertificate) + + withLast -> withExec + } + } + } + } + + /** The leader's message handling is the same across all phases: + * add the vote to the list; if we reached `n-f` then combine + * into a Q.C. and broadcast. + * + * It can also receive messages beyond the `n-f` it needed, + * which it can ignore. + */ + private def handleVotes( + event: MessageReceived[A], + phase: VotingPhase + ): PartialFunction[Message[A], TransitionAttempt[A]] = { + // Check that a vote is compatible with our current expectations. + case v: Vote[_] + if isLeader && v.viewNumber == viewNumber && + v.phase == phase && + v.blockHash == preparedBlockHash => + Right(addVoteAndMaybeBroadcastQC(v)) + + // Once the leader moves on to the next phase, it can still receive votes + // for the previous one. These can be ignored, they are not unexpected. + case v: Vote[_] + if isLeader && + v.viewNumber == viewNumber && + v.phase.isBefore(phase) && + v.blockHash == preparedBlockHash => + Right(stay) + + // Ignore votes for other blocks. + case v: Vote[_] + if isLeader && v.viewNumber == viewNumber && + v.phase == phase && + v.blockHash != preparedBlockHash => + Left(UnexpectedBlockHash(event, preparedBlockHash)) + + case v: NewView[_] if isLeader && v.viewNumber == viewNumber.prev => + Right(stay) + } + + private def handleQuorum( + event: Validated[MessageReceived[A]], + phase: VotingPhase + )( + f: Quorum[A] => Transition[A] + ): PartialFunction[Message[A], TransitionAttempt[A]] = { + case m: Quorum[_] + if matchingLeader(event) && + m.quorumCertificate.viewNumber == viewNumber && + m.quorumCertificate.phase == phase && + m.quorumCertificate.blockHash == preparedBlockHash => + Right(f(m)) + + case m: Quorum[_] + if matchingLeader(event) && + m.quorumCertificate.viewNumber == viewNumber && + m.quorumCertificate.phase == phase && + m.quorumCertificate.blockHash != preparedBlockHash => + Left(UnexpectedBlockHash(event, preparedBlockHash)) + } + + /** Categorize unexpected messages into ones that can be re-queued or discarded. + * + * At this point we already know that the messages have been validated once, + * so at at least they are consistent with their own view, e.g. sending to the + * leader of their own view. 
+ */ + private def handleUnexpected(e: MessageReceived[A]): ProtocolError[A] = { + e.message match { + case m: NewView[_] if m.viewNumber >= viewNumber => + TooEarly(e, m.viewNumber.next, Phase.Prepare) + + case m: Prepare[_] if m.viewNumber > viewNumber => + TooEarly(e, m.viewNumber, Phase.Prepare) + + case m: Vote[_] + if m.viewNumber > viewNumber || + m.viewNumber == viewNumber && m.phase.isAfter(phase.prev) => + TooEarly(e, m.viewNumber, m.phase.next) + + case m: Quorum[_] + if m.quorumCertificate.viewNumber > viewNumber || + m.quorumCertificate.viewNumber == viewNumber && + m.quorumCertificate.phase.isAfter(phase.prev) => + TooEarly(e, m.viewNumber, m.quorumCertificate.phase.next) + + case _ => + Unexpected(e) + } + } + + /** Try to match a message to expectations, or return Unexpected. */ + private def matchingMsg(e: MessageReceived[A])( + pf: PartialFunction[Message[A], TransitionAttempt[A]] + ): TransitionAttempt[A] = + pf.lift(e.message).getOrElse(Left(handleUnexpected(e))) + + /** Check that a message is coming from the view leader and is for the current phase. */ + private def matchingLeader(e: MessageReceived[A]): Boolean = + e.message.viewNumber == viewNumber && + e.sender == federation.leaderOf(viewNumber) + + /** Broadcast a message from the leader to all replicas. + * + * This includes the leader sending a message to itself, + * because the leader is a replica as well. The effect + * system should take care that these messages don't + * try to go over the network. + * + * NOTE: Some messages trigger transitions; it's best + * if the message sent to the leader by itself is handled + * before the other messages are sent out to avoid any + * votes coming in return coming in phases that don't + * yet expect them. + */ + private def broadcast(m: Message[A]): Seq[Effect[A]] = + federation.publicKeys.map { pk => + SendMessage(pk, m) + } + + /** Produce a vote with the current view number. */ + private def vote(phase: VotingPhase, blockHash: A#Hash): Vote[A] = { + val signature = Signing[A].sign(signingKey, phase, viewNumber, blockHash) + Vote(viewNumber, phase, blockHash, signature) + } + + private def sendVote(phase: VotingPhase, blockHash: A#Hash): SendMessage[A] = + SendMessage(leader, vote(phase, blockHash)) + + /** Check that the proposed new block extends the locked Q.C. (safety) + * or that the Quorum Certificate is newer than the locked Q.C. (liveness). + */ + private def isSafe(m: Prepare[A]): Boolean = { + val valid = isExtension(m.block, m.highQC) + val safe = isExtension(m.block, lockedQC) + val live = m.highQC.viewNumber > lockedQC.viewNumber + + valid && (safe || live) + } + + /** Check that a block extends from the one in the Q.C. + * + * Currently only allows direct parent-child relationship, + * which means each leader is expected to create max 1 block + * on top of the previous high Q.C. + */ + private def isExtension(block: A#Block, qc: QuorumCertificate[A]): Boolean = + qc.blockHash == Block[A].parentBlockHash(block) + + /** Register a new vote; if there are enough to form a new Q.C., + * do so and broadcast it. + */ + private def addVoteAndMaybeBroadcastQC(vote: Vote[A]): Transition[A] = { + // `matchingVote` made sure all votes are for the same content, + // and `moveTo` clears the votes, so they should be uniform. + val next = copy(votes = votes + vote) + + // Only make the quorum certificate once. 
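+    // (The guard below fires exactly when this vote brings the count up to
+    // `quorumSize`; any votes beyond `n-f` are ignored by `handleVotes`.)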
+    val effects =
+      if (votes.size < quorumSize && next.votes.size == quorumSize) {
+        val vs = next.votes.toSeq
+        val qc = QuorumCertificate(
+          phase = vs.head.phase,
+          viewNumber = vs.head.viewNumber,
+          blockHash = vs.head.blockHash,
+          signature = Signing[A].combine(vs.map(_.signature))
+        )
+        broadcast {
+          Quorum(viewNumber, qc)
+        }
+      } else Nil
+
+    // The move to the next phase will be triggered when the Q.C. is delivered.
+    next -> effects
+  }
+
+  /** Register a NewView from a replica; if there are enough, select the High Q.C. and create a block. */
+  private def addNewViewAndMaybeCreateBlock(
+      sender: A#PKey,
+      newView: NewView[A]
+  ): Transition[A] = {
+    // We already checked that these are for the current view.
+    val next = copy(newViews =
+      newViews.updated(
+        sender,
+        newViews.get(sender).fold(newView) { oldView =>
+          if (newView.prepareQC.viewNumber > oldView.prepareQC.viewNumber)
+            newView
+          else oldView
+        }
+      )
+    )
+
+    // Only make a block once.
+    val effects =
+      if (newViews.size < quorumSize && next.newViews.size == quorumSize) {
+        List(
+          CreateBlock(
+            viewNumber,
+            highQC = next.newViews.values.map(_.prepareQC).maxBy(_.viewNumber)
+          )
+        )
+      } else Nil
+
+    // The move to the next phase will be triggered when the block is created.
+    next -> effects
+  }
+}
+
+object ProtocolState {
+
+  /** The result of a state transition is the next state and some effects
+    * that can be carried out in parallel.
+    */
+  type Transition[A <: Agreement] = (ProtocolState[A], Seq[Effect[A]])
+
+  type TransitionAttempt[A <: Agreement] =
+    Either[ProtocolError[A], Transition[A]]
+
+  /** Return an initial set of effects; at the minimum the timeout for the first round. */
+  def init[A <: Agreement](state: ProtocolState[A]): Seq[Effect[A]] =
+    List(Effect.ScheduleNextView(state.viewNumber, state.timeout))
+} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/QuorumCertificate.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/QuorumCertificate.scala new file mode 100644 index 00000000..1390e062 --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/QuorumCertificate.scala @@ -0,0 +1,14 @@ +package io.iohk.metronome.hotstuff.consensus.basic
+
+import io.iohk.metronome.crypto.GroupSignature
+import io.iohk.metronome.hotstuff.consensus.ViewNumber
+
+/** A Quorum Certificate (QC) over a tuple (message-type, view-number, block-hash) is a data type
+  * that combines a collection of signatures for the same tuple signed by (n − f) replicas.
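+  *
+  * For illustration: with the defaults exercised in `FederationSpec`, a federation
+  * of `n = 10` members tolerating `f = 3` faults needs 7 signatures, which is the
+  * `Federation.quorumSize` used when combining and validating votes.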
+ */ +case class QuorumCertificate[A <: Agreement]( + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: A#Hash, + signature: GroupSignature[A#PKey, (VotingPhase, ViewNumber, A#Hash), A#GSig] +) diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Secp256k1Agreement.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Secp256k1Agreement.scala new file mode 100644 index 00000000..0aa5103e --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Secp256k1Agreement.scala @@ -0,0 +1,12 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +import io.iohk.ethereum.crypto.ECDSASignature +import io.iohk.metronome.crypto.{ECPrivateKey, ECPublicKey} + +trait Secp256k1Agreement extends Agreement { + override final type SKey = ECPrivateKey + override final type PKey = ECPublicKey + override final type PSig = ECDSASignature + // TODO (PM-2935): Replace list with theshold signatures. + override final type GSig = List[ECDSASignature] +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Secp256k1Signing.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Secp256k1Signing.scala new file mode 100644 index 00000000..60e44628 --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Secp256k1Signing.scala @@ -0,0 +1,87 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +import io.iohk.ethereum.crypto.ECDSASignature +import io.iohk.metronome.crypto.hash.Keccak256 +import io.iohk.metronome.crypto.{ + ECPrivateKey, + ECPublicKey, + GroupSignature, + PartialSignature +} +import io.iohk.metronome.hotstuff.consensus.basic.Signing.{GroupSig, PartialSig} +import io.iohk.metronome.hotstuff.consensus.{Federation, ViewNumber} +import scodec.bits.ByteVector + +/** Facilitates a Secp256k1 elliptic curve signing scheme using + * `io.iohk.ethereum.crypto.ECDSASignature` + * A group signature is simply a concatenation (sequence) of partial signatures + */ +class Secp256k1Signing[A <: Secp256k1Agreement]( + contentSerializer: (VotingPhase, ViewNumber, A#Hash) => ByteVector +) extends Signing[A] { + + override def sign( + signingKey: ECPrivateKey, + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: A#Hash + ): PartialSig[A] = { + val msgHash = contentHash(phase, viewNumber, blockHash) + PartialSignature(ECDSASignature.sign(msgHash, signingKey.underlying)) + } + + override def combine( + signatures: Seq[PartialSig[A]] + ): GroupSig[A] = + GroupSignature(signatures.map(_.sig).toList) + + /** Validate that partial signature was created by a given public key. + * + * Check that the signer is part of the federation. + */ + override def validate( + publicKey: ECPublicKey, + signature: PartialSig[A], + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: A#Hash + ): Boolean = { + val msgHash = contentHash(phase, viewNumber, blockHash) + signature.sig + .publicKey(msgHash) + .map(ECPublicKey(_)) + .contains(publicKey) + } + + /** Validate a group signature. + * + * Check that enough members of the federation signed, + * and only the members. 
+ */ + override def validate( + federation: Federation[ECPublicKey], + signature: GroupSig[A], + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: A#Hash + ): Boolean = { + val msgHash = contentHash(phase, viewNumber, blockHash) + val signers = + signature.sig + .flatMap(s => s.publicKey(msgHash).map(ECPublicKey(_))) + .toSet + + val areUniqueSigners = signers.size == signature.sig.size + val areFederationMembers = (signers -- federation.publicKeys).isEmpty + val isQuorumReached = signers.size == federation.quorumSize + + areUniqueSigners && areFederationMembers && isQuorumReached + } + + private def contentHash( + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: A#Hash + ): Array[Byte] = + Keccak256(contentSerializer(phase, viewNumber, blockHash)).toArray +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Signing.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Signing.scala new file mode 100644 index 00000000..f5cf31b7 --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/basic/Signing.scala @@ -0,0 +1,76 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +import io.iohk.metronome.crypto.{GroupSignature, PartialSignature} +import io.iohk.metronome.hotstuff.consensus.{Federation, ViewNumber} +import scodec.bits.ByteVector + +trait Signing[A <: Agreement] { + + def sign( + signingKey: A#SKey, + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: A#Hash + ): Signing.PartialSig[A] + + def combine( + signatures: Seq[Signing.PartialSig[A]] + ): Signing.GroupSig[A] + + /** Validate that partial signature was created by a given public key. */ + def validate( + publicKey: A#PKey, + signature: Signing.PartialSig[A], + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: A#Hash + ): Boolean + + /** Validate a group signature. + * + * Check that enough members of the federation signed, + * and only the members. 
+ */ + def validate( + federation: Federation[A#PKey], + signature: Signing.GroupSig[A], + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: A#Hash + ): Boolean + + def validate(sender: A#PKey, vote: Message.Vote[A]): Boolean = + validate( + sender, + vote.signature, + vote.phase, + vote.viewNumber, + vote.blockHash + ) + + def validate( + federation: Federation[A#PKey], + quorumCertificate: QuorumCertificate[A] + ): Boolean = + validate( + federation, + quorumCertificate.signature, + quorumCertificate.phase, + quorumCertificate.viewNumber, + quorumCertificate.blockHash + ) +} + +object Signing { + def apply[A <: Agreement: Signing]: Signing[A] = implicitly[Signing[A]] + + def secp256k1[A <: Secp256k1Agreement]( + contentSerializer: (VotingPhase, ViewNumber, A#Hash) => ByteVector + ): Signing[A] = new Secp256k1Signing[A](contentSerializer) + + type PartialSig[A <: Agreement] = + PartialSignature[A#PKey, (VotingPhase, ViewNumber, A#Hash), A#PSig] + + type GroupSig[A <: Agreement] = + GroupSignature[A#PKey, (VotingPhase, ViewNumber, A#Hash), A#GSig] +} diff --git a/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/package.scala b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/package.scala new file mode 100644 index 00000000..94f9988f --- /dev/null +++ b/metronome/hotstuff/consensus/src/io/iohk/metronome/hotstuff/consensus/package.scala @@ -0,0 +1,5 @@ +package io.iohk.metronome.hotstuff + +package object consensus { + type ViewNumber = ViewNumber.Tagged +} diff --git a/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/ArbitraryInstances.scala b/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/ArbitraryInstances.scala new file mode 100644 index 00000000..16084920 --- /dev/null +++ b/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/ArbitraryInstances.scala @@ -0,0 +1,32 @@ +package io.iohk.metronome.hotstuff.consensus + +import io.iohk.metronome.crypto.hash.Hash +import io.iohk.metronome.hotstuff.consensus.basic.Phase.{ + Commit, + PreCommit, + Prepare +} +import io.iohk.metronome.hotstuff.consensus.basic.VotingPhase +import org.scalacheck.Arbitrary.arbitrary +import org.scalacheck.{Arbitrary, Gen} +import scodec.bits.ByteVector + +trait ArbitraryInstances { + + def sample[T: Arbitrary]: T = arbitrary[T].sample.get + + implicit val arbViewNumber: Arbitrary[ViewNumber] = Arbitrary { + Gen.posNum[Long].map(ViewNumber(_)) + } + + implicit val arbVotingPhase: Arbitrary[VotingPhase] = Arbitrary { + Gen.oneOf(Prepare, PreCommit, Commit) + } + + implicit val arbHash: Arbitrary[Hash] = + Arbitrary { + Gen.listOfN(32, arbitrary[Byte]).map(ByteVector(_)).map(Hash(_)) + } +} + +object ArbitraryInstances extends ArbitraryInstances diff --git a/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/FederationSpec.scala b/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/FederationSpec.scala new file mode 100644 index 00000000..4e94f78f --- /dev/null +++ b/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/FederationSpec.scala @@ -0,0 +1,60 @@ +package io.iohk.metronome.hotstuff.consensus + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.Inside +import org.scalatest.prop.TableDrivenPropertyChecks._ + +class FederationSpec extends AnyFlatSpec with Matchers with Inside { + + implicit val ls = LeaderSelection.RoundRobin + + behavior of "Federation" + + 
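+  // The expected values below are consistent with `maxFaulty` defaulting to
+  // (n - 1) / 3 and `quorumSize` being (n + f) / 2 + 1, the same formula the
+  // protocol model in `ProtocolStateCommands` uses.
+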
it should "not create an empty federation" in { + Federation(Vector.empty).isLeft shouldBe true + } + + it should "not create a federation with duplicate keys" in { + Federation(Vector(1, 2, 1)).isLeft shouldBe true + } + + it should "not create a federation with too high configured f" in { + Federation(1 to 4, maxFaulty = 2).isLeft shouldBe true + } + + it should "determine the correct f and q based on n" in { + val examples = Table( + ("n", "f", "q"), + (10, 3, 7), + (1, 0, 1), + (3, 0, 2), + (4, 1, 3) + ) + forAll(examples) { case (n, f, q) => + inside(Federation(1 to n)) { case Right(federation) => + federation.maxFaulty shouldBe f + federation.quorumSize shouldBe q + } + } + } + + it should "use lower quorum size if there are less faulties" in { + val examples = Table( + ("n", "f", "q"), + (10, 2, 7), + (10, 1, 6), + (10, 0, 6), + (9, 0, 5), + (100, 0, 51), + (100, 1, 51) + ) + forAll(examples) { case (n, f, q) => + inside(Federation(1 to n, f)) { case Right(federation) => + federation.maxFaulty shouldBe f + federation.quorumSize shouldBe q + } + } + } + +} diff --git a/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/LeaderSelectionProps.scala b/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/LeaderSelectionProps.scala new file mode 100644 index 00000000..ff8e4c1e --- /dev/null +++ b/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/LeaderSelectionProps.scala @@ -0,0 +1,41 @@ +package io.iohk.metronome.hotstuff.consensus + +import io.iohk.metronome.core.Tagger +import io.iohk.metronome.hotstuff.consensus.ArbitraryInstances._ +import org.scalacheck.Prop.forAll +import org.scalacheck._ + +abstract class LeaderSelectionProps(name: String, val selector: LeaderSelection) + extends Properties(name) { + + object Size extends Tagger[Int] + type Size = Size.Tagged + + implicit val arbFederationSize: Arbitrary[Size] = Arbitrary { + Gen.posNum[Int].map(Size(_)) + } + + property("leaderOf") = forAll { (viewNumber: ViewNumber, size: Size) => + val idx = selector.leaderOf(viewNumber, size) + 0 <= idx && idx < size + } +} + +object RoundRobinSelectionProps + extends LeaderSelectionProps( + "LeaderSelection.RoundRobin", + LeaderSelection.RoundRobin + ) { + + property("round-robin") = forAll { (viewNumber: ViewNumber, size: Size) => + val idx0 = selector.leaderOf(viewNumber, size) + val idx1 = selector.leaderOf(viewNumber.next, size) + idx1 == idx0 + 1 || idx0 == size - 1 && idx1 == 0 + } +} + +object HashingSelectionProps + extends LeaderSelectionProps( + "LeaderSelection.Hashing", + LeaderSelection.Hashing + ) diff --git a/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/basic/ProtocolStateProps.scala b/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/basic/ProtocolStateProps.scala new file mode 100644 index 00000000..a363a36a --- /dev/null +++ b/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/basic/ProtocolStateProps.scala @@ -0,0 +1,948 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +import io.iohk.metronome.crypto.{GroupSignature, PartialSignature} +import io.iohk.metronome.hotstuff.consensus.{ + ViewNumber, + Federation, + LeaderSelection +} +import org.scalacheck.commands.Commands +import org.scalacheck.{Properties, Gen, Prop} +import org.scalacheck.Arbitrary.arbitrary +import org.scalacheck.Prop.{propBoolean, all, falsified} +import scala.annotation.nowarn +import scala.concurrent.duration._ +import scala.util.{Try, Failure, 
Success} + +object ProtocolStateProps extends Properties("Basic HotStuff") { + + property("protocol") = ProtocolStateCommands.property() + +} + +/** State machine tests for the Basic HotStuff protocol. + * + * The `Model` class has enough reflection of the state so that we can generate valid + * and invalid commands using `genCommand`. Each `Command`, has its individual post-condition + * check comparing the model state to the actual protocol results. + */ +object ProtocolStateCommands extends Commands { + + case class TestBlock( + blockHash: Int, + parentBlockHash: Int, + command: String + ) + + object TestAgreement extends Agreement { + type Block = TestBlock + type Hash = Int + type PSig = Long + type GSig = Seq[Long] + type PKey = Int + type SKey = Int + } + type TestAgreement = TestAgreement.type + + val genesis = + TestBlock(blockHash = 0, parentBlockHash = -1, command = "") + + val genesisQC = QuorumCertificate[TestAgreement]( + phase = Phase.Prepare, + viewNumber = ViewNumber(0), + blockHash = genesis.blockHash, + signature = GroupSignature(Nil) + ) + + implicit val block: Block[TestAgreement] = new Block[TestAgreement] { + override def blockHash(b: TestBlock) = b.blockHash + override def parentBlockHash(b: TestBlock) = b.parentBlockHash + override def height(b: TestBlock): Long = 0 // Not used by this model. + override def isValid(b: TestBlock) = true + } + + implicit val leaderSelection = LeaderSelection.Hashing + + // Going to use publicKey == -1 * signingKey. + def mockSigningKey(pk: TestAgreement.PKey): TestAgreement.SKey = -1 * pk + + // Mock signatures. + implicit val mockSigning: Signing[TestAgreement] = + new Signing[TestAgreement] { + private def hash( + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: TestAgreement.Hash + ): TestAgreement.Hash = + (phase, viewNumber, blockHash).hashCode + + private def isGenesis( + viewNumber: ViewNumber, + blockHash: TestAgreement.Hash + ): Boolean = + viewNumber == genesisQC.viewNumber && + blockHash == genesisQC.blockHash + + private def sign( + sk: TestAgreement.SKey, + h: TestAgreement.Hash + ): TestAgreement.PSig = + h + sk + + private def unsign( + s: TestAgreement.PSig, + h: TestAgreement.Hash + ): TestAgreement.PKey = + ((s - h) * -1).toInt + + override def sign( + signingKey: TestAgreement#SKey, + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: TestAgreement.Hash + ): Signing.PartialSig[TestAgreement] = { + val h = hash(phase, viewNumber, blockHash) + val s = sign(signingKey, h) + PartialSignature(s) + } + + override def combine( + signatures: Seq[Signing.PartialSig[TestAgreement]] + ): Signing.GroupSig[TestAgreement] = + GroupSignature(signatures.map(_.sig)) + + override def validate( + publicKey: TestAgreement.PKey, + signature: Signing.PartialSig[TestAgreement], + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: TestAgreement.Hash + ): Boolean = { + val h = hash(phase, viewNumber, blockHash) + publicKey == unsign(signature.sig, h) + } + + override def validate( + federation: Federation[TestAgreement.PKey], + signature: Signing.GroupSig[TestAgreement], + phase: VotingPhase, + viewNumber: ViewNumber, + blockHash: TestAgreement.Hash + ): Boolean = { + if (isGenesis(viewNumber, blockHash)) { + signature.sig.isEmpty + } else { + val h = hash(phase, viewNumber, blockHash) + + signature.sig.size == federation.quorumSize && + signature.sig.forall { sig => + federation.publicKeys.exists { publicKey => + publicKey == unsign(sig, h) + } + } + } + } + } + + case class Model( + n: Int, + f: Int, + 
viewNumber: ViewNumber, + phase: Phase, + federation: Vector[TestAgreement.PKey], + ownIndex: Int, + votesFrom: Set[TestAgreement.PKey], + newViewsFrom: Set[TestAgreement.PKey], + newViewsHighQC: QuorumCertificate[TestAgreement], + prepareQCs: List[QuorumCertificate[TestAgreement]], + maybeBlockHash: Option[TestAgreement.Hash] + ) { + def publicKey = federation(ownIndex) + + // Using a signing key that works with the mock validation. + def signingKey = mockSigningKey(publicKey) + + def leaderIndex = leaderSelection.leaderOf(viewNumber, n) + def isLeader = leaderIndex == ownIndex + def leader = federation(leaderIndex) + + def quorumSize = (n + f) / 2 + 1 + } + + // Keep a variable state in our System Under Test. + class Protocol(var state: ProtocolState[TestAgreement]) + + type Sut = Protocol + type State = Model + + @nowarn + override def canCreateNewSut( + newState: State, + initSuts: Traversable[State], + runningSuts: Traversable[Sut] + ): Boolean = true + + override def initialPreCondition(state: State): Boolean = + state.viewNumber == 1 && + state.phase == Phase.Prepare && + state.votesFrom.isEmpty && + state.newViewsFrom.isEmpty + + override def newSut(state: State): Sut = + new Protocol( + ProtocolState[TestAgreement]( + viewNumber = ViewNumber(state.viewNumber), + phase = state.phase, + publicKey = state.publicKey, + signingKey = state.signingKey, + federation = Federation(state.federation, state.f) + .getOrElse(sys.error("Invalid federation!")), + prepareQC = genesisQC, + lockedQC = genesisQC, + commitQC = genesisQC, + preparedBlock = genesis, + timeout = 10.seconds, + votes = Set.empty, + newViews = Map.empty + ) + ) + + override def destroySut(sut: Sut): Unit = () + + override def genInitialState: Gen[State] = + for { + n <- Gen.choose(1, 10) + f <- Gen.choose(0, (n - 1) / 3) + + ownIndex <- Gen.choose(0, n - 1) + + // Create unique keys. + publicKeys <- Gen + .listOfN(n, Gen.posNum[Int]) + .map { ns => + ns.tail.scan(ns.head)(_ + _) + } + .retryUntil(_.size == n) + + } yield Model( + n, + f, + viewNumber = ViewNumber(1), + phase = Phase.Prepare, + federation = publicKeys.toVector, + ownIndex = ownIndex, + votesFrom = Set.empty, + newViewsFrom = Set.empty, + newViewsHighQC = genesisQC, + prepareQCs = List(genesisQC), + maybeBlockHash = None + ) + + /** Generate valid and invalid commands depending on state. + * + * Invalid commands are marked as such, so we don't have to repeat validations here + * to tell what we expect the response to be. We can send invalid commands from up + * to `f` Bzyantine members of the federation. The rest should be honest, but they + * might still send commands which are delivered in a different state, e.g. because + * they didn't have the data available to validate a proposal. + */ + override def genCommand(state: State): Gen[Command] = + Gen.frequency( + 7 -> genValid(state), + 2 -> genInvalid(state), + 1 -> genTimeout(state) + ) + + def fail(msg: String) = msg |: falsified + + def votingPhaseFor(phase: Phase): Option[VotingPhase] = + phase match { + case Phase.Prepare => None + case Phase.PreCommit => Some(Phase.Prepare) + case Phase.Commit => Some(Phase.PreCommit) + case Phase.Decide => Some(Phase.Commit) + } + + def genTimeout(state: State): Gen[NextViewCmd] = + Gen.const(NextViewCmd(state.viewNumber)) + + /** Geneerate a valid input for the givens state. */ + def genValid(state: State): Gen[Command] = { + val usables: List[Gen[Command]] = + List( + // The leader may receive NewView any time. 
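+        // (Valid `NewView` messages are generated for the previous view, so they
+        // remain acceptable throughout the leader's current view.)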
+ genValidNewView(state) -> + state.isLeader, + // The leader can get a block generated by the host system in Prepare. + genValidBlock(state) -> + (state.phase == Phase.Prepare && state.isLeader && state.maybeBlockHash.isEmpty), + // Replicas can get a Prepared block in Prepare (for the leader this should match the created block). + genValidPrepare(state) -> + (state.phase == Phase.Prepare && + (state.isLeader && state.maybeBlockHash.isDefined || + !state.isLeader && state.maybeBlockHash.isEmpty)), + // The leader can get votes on the block it created, except in Prepare. + genValidVote(state) -> + (state.phase != Phase.Prepare && state.isLeader && state.maybeBlockHash.isDefined), + // Replicas can get a Quroum on the block that was Prepared, except in Prepare. + genValidQuorum(state) -> + (state.phase != Phase.Prepare && state.maybeBlockHash.isDefined) + ).collect { + case (gen, usable) if usable => gen + } + + usables match { + case Nil => genTimeout(state) + case one :: Nil => one + case one :: two :: rest => Gen.oneOf(one, two, rest: _*) + } + } + + /** Take an valid command and turn it invalid. */ + def genInvalid(state: State): Gen[Command] = { + def nextVoting(phase: Phase): VotingPhase = { + phase.next match { + case p: VotingPhase => p + case p => nextVoting(p) + } + } + + def invalidateHash(h: TestAgreement.Hash) = h * 2 + 1 + def invalidateSig(s: TestAgreement.PSig) = s * 2 + 1 + def invalidateViewNumber(v: ViewNumber) = ViewNumber(v + 1000) + def invalidSender = state.federation.sum + 1 + + def invalidateQC( + qc: QuorumCertificate[TestAgreement] + ): Gen[QuorumCertificate[TestAgreement]] = { + Gen.oneOf( + genLazy( + qc.copy[TestAgreement](blockHash = invalidateHash(qc.blockHash)) + ), + genLazy(qc.copy[TestAgreement](phase = nextVoting(qc.phase))) + .suchThat(_.blockHash != genesisQC.blockHash), + genLazy( + qc.copy[TestAgreement](viewNumber = + invalidateViewNumber(qc.viewNumber) + ) + ), + genLazy( + qc.copy[TestAgreement](signature = + // The quorum cert has no items, so add one to make it different. + qc.signature.copy(sig = 0L +: qc.signature.sig.map(invalidateSig)) + ) + ) + ) + } + + implicit class StringOps(label: String) { + def `!`(gen: Gen[MessageCmd]): Gen[InvalidCmd] = + gen.map(cmd => InvalidCmd(label, cmd, isEarly = label == "viewNumber")) + } + + genValid(state) flatMap { + case msg: MessageCmd => + msg match { + case cmd @ NewViewCmd(_, m) => + Gen.oneOf( + "sender" ! genLazy(cmd.copy(sender = invalidSender)), + "viewNumber" ! genLazy( + cmd.copy(message = + m.copy(viewNumber = invalidateViewNumber(m.viewNumber)) + ) + ), + "prepareQC" ! invalidateQC(m.prepareQC).map { qc => + cmd.copy(message = m.copy(prepareQC = qc)) + } + ) + + case cmd @ PrepareCmd(_, m) => + Gen.oneOf( + "sender" ! genLazy(cmd.copy(sender = invalidSender)), + "viewNumber" ! genLazy( + cmd.copy(message = m.copy(viewNumber = m.viewNumber.next)) + ), + "parentBlockHash" ! genLazy( + cmd.copy(message = + m.copy[TestAgreement](block = + m.block + .copy(parentBlockHash = + invalidateHash(m.block.parentBlockHash) + ) + ) + ) + ), + "highQC" ! invalidateQC(m.highQC).map { qc => + cmd.copy(message = m.copy(highQC = qc)) + } + ) + + case cmd @ VoteCmd(_, m) => + Gen.oneOf( + "sender" ! genLazy(cmd.copy(sender = invalidSender)), + "viewNumber" ! genLazy( + cmd.copy(message = + m.copy[TestAgreement](viewNumber = + invalidateViewNumber(m.viewNumber) + ) + ) + ), + "phase" ! genLazy( + cmd.copy(message = + m.copy[TestAgreement](phase = nextVoting(m.phase)) + ) + ), + "blockHash" ! 
genLazy( + cmd.copy(message = + m.copy[TestAgreement](blockHash = invalidateHash(m.blockHash)) + ) + ), + "signature" ! genLazy( + cmd.copy(message = + m.copy[TestAgreement](signature = + m.signature.copy(sig = invalidateSig(m.signature.sig)) + ) + ) + ) + ) + + case cmd @ QuorumCmd(_, m) => + Gen.oneOf( + "sender" ! genLazy(cmd.copy(sender = invalidSender)), + "quorumCertificate" ! invalidateQC(m.quorumCertificate).map { + qc => + cmd.copy(message = m.copy(quorumCertificate = qc)) + } + ) + } + + // Leave anything else alone. + case other => Gen.const(other) + } + } + + /** A constant expression, but only evaluated if the generator is chosen, + * which allows us to have conditions attached to it. + */ + def genLazy[A](a: => A): Gen[A] = Gen.lzy(Gen.const(a)) + + /** Replica sends a new view with an arbitrary prepare QC. */ + def genValidNewView(state: State): Gen[NewViewCmd] = + for { + s <- Gen.oneOf(state.federation) + qc <- Gen.oneOf(state.prepareQCs) + m = Message.NewView(ViewNumber(state.viewNumber - 1), qc) + } yield NewViewCmd(s, m) + + /** Leader creates a valid block on top of the saved High Q.C. */ + def genValidBlock(state: State): Gen[BlockCreatedCmd] = + for { + c <- arbitrary[String] + h <- genHash + qc = state.prepareQCs.head // So that it's a safe extension. + p = qc.blockHash + b = TestBlock(h, p, c) + e = Event + .BlockCreated[TestAgreement](state.viewNumber, b, qc) + } yield BlockCreatedCmd(e) + + /** Leader sends a valid Prepare command with the generated block. */ + def genValidPrepare(state: State): Gen[PrepareCmd] = + for { + blockCreated <- genValidBlock(state).map(_.event).map { bc => + bc.copy[TestAgreement]( + block = bc.block.copy( + blockHash = state.maybeBlockHash.getOrElse(bc.block.blockHash) + ) + ) + } + } yield { + PrepareCmd( + sender = state.leader, + message = Message.Prepare( + state.viewNumber, + blockCreated.block, + blockCreated.highQC + ) + ) + } + + /** Replica sends a valid vote for the current phase and prepared block. */ + def genValidVote(state: State): Gen[VoteCmd] = + for { + blockHash <- genLazy { + state.maybeBlockHash.getOrElse(sys.error("No block to vote on.")) + } + // The leader is expecting votes for the previous phase. + phase = votingPhaseFor(state.phase).getOrElse( + sys.error(s"No voting phase for ${state.phase}") + ) + sender <- Gen.oneOf(state.federation) + vote = Message.Vote[TestAgreement]( + state.viewNumber, + phase, + blockHash, + signature = mockSigning.sign( + mockSigningKey(sender), + phase, + state.viewNumber, + blockHash + ) + ) + } yield VoteCmd(sender, vote) + + /** Leader sends a valid quorum from the collected votes. */ + def genValidQuorum(state: State): Gen[QuorumCmd] = + for { + blockHash <- genLazy { + state.maybeBlockHash.getOrElse(sys.error("No block for quorum.")) + } + pks <- Gen.pick(state.quorumSize, state.federation) + // The replicas is expecting the Q.C. for the previous phase. + phase = votingPhaseFor(state.phase).getOrElse( + sys.error(s"No voting phase for ${state.phase}") + ) + qc = QuorumCertificate[TestAgreement]( + phase, + state.viewNumber, + blockHash, + signature = mockSigning.combine( + pks.toList.map { pk => + mockSigning.sign( + mockSigningKey(pk), + phase, + state.viewNumber, + blockHash + ) + } + ) + ) + q = Message.Quorum(state.viewNumber, qc) + } yield QuorumCmd(state.leader, q) + + // A positive hash, not the same as Genesis. + val genHash: Gen[TestAgreement.Hash] = + arbitrary[Int].map(math.abs(_) + 1) + + /** Timeout. 
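+    * Models `Event.NextView` being delivered when the round timer scheduled by
+    * `Effect.ScheduleNextView` fires.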
*/ + case class NextViewCmd(viewNumber: ViewNumber) extends Command { + type Result = ProtocolState.Transition[TestAgreement] + + def run(sut: Sut): Result = { + sut.state.handleNextView(Event.NextView(viewNumber)) match { + case result @ (next, _) => + sut.state = next + result + } + } + + def nextState(state: State): State = + state.copy( + viewNumber = ViewNumber(state.viewNumber + 1), + phase = Phase.Prepare, + votesFrom = Set.empty, + // In this model there's not a guaranteed message from the leader to itself. + newViewsFrom = Set.empty, + newViewsHighQC = genesisQC, + maybeBlockHash = None + ) + + def preCondition(state: State): Boolean = + viewNumber == state.viewNumber + + def postCondition(state: Model, result: Try[Result]): Prop = + "NextView" |: { + result match { + case Failure(exception) => + fail(s"unexpected $exception") + + case Success((next, effects)) => + val propNewView = effects + .collectFirst { + case Effect.SendMessage( + recipient, + Message.NewView(viewNumber, prepareQC) + ) => + "sends the new view to the next leader" |: + recipient == next.leader && + viewNumber == state.viewNumber && + prepareQC == next.prepareQC + } + .getOrElse(fail("didn't send the new view")) + + val propSchedule = effects + .collectFirst { + case Effect.ScheduleNextView( + viewNumber, + timeout + ) => + "schedules the next view" |: + viewNumber == next.viewNumber && + timeout == next.timeout + } + .getOrElse(fail("didn't schedule the next view")) + + val propNext = "goes to the next phase" |: + next.phase == Phase.Prepare && + next.viewNumber == state.viewNumber + 1 && + next.votes.isEmpty && + next.newViews.isEmpty + + propNext && + propNewView && + propSchedule && + ("only has the expected effects" |: effects.size == 2) + } + } + } + + /** Common logic of handling a received message */ + sealed trait MessageCmd extends Command { + type Result = ProtocolState.TransitionAttempt[TestAgreement] + + def sender: TestAgreement.PKey + def message: Message[TestAgreement] + + override def run(sut: Protocol): Result = { + val event = Event.MessageReceived(sender, message) + sut.state.validateMessage(event).flatMap(sut.state.handleMessage).map { + case result @ (next, _) => + sut.state = next + result + } + } + } + + /** NewView from a replicas to the leader. 
*/ + case class NewViewCmd( + sender: TestAgreement.PKey, + message: Message.NewView[TestAgreement] + ) extends MessageCmd { + override def nextState(state: State): State = + state.copy( + newViewsFrom = state.newViewsFrom + sender, + newViewsHighQC = + if (message.prepareQC.viewNumber > state.newViewsHighQC.viewNumber) + message.prepareQC + else state.newViewsHighQC + ) + + override def preCondition(state: State): Boolean = + state.isLeader && state.viewNumber == message.viewNumber + 1 + + override def postCondition( + state: Model, + result: Try[Result] + ): Prop = { + val nextS = nextState(state) + "NewView" |: { + if ( + state.phase == Phase.Prepare && + state.newViewsFrom.size != state.quorumSize && + nextS.newViewsFrom.size == state.quorumSize + ) { + result match { + case Success(Right((next, effects))) => + val newViewsMax = nextS.newViewsHighQC.viewNumber + val highestView = effects.headOption match { + case Some(Effect.CreateBlock(_, highQC)) => + highQC.viewNumber.toInt + case _ => -1 + } + + "n-f collected" |: all( + s"stays in the phase (${state.phase} -> ${next.phase})" |: next.phase == state.phase, + "records newView" |: next.newViews.size == state.quorumSize, + "creates a block and nothing else" |: effects.size == 1 && + effects.head.isInstanceOf[Effect.CreateBlock[_]], + s"selects the highest QC: $highestView ?= $newViewsMax" |: highestView == newViewsMax + ) + case err => + fail(s"unexpected $err") + } + } else { + result match { + case Success(Right((next, effects))) => + "n-f not collected" |: all( + s"stays in the same phase (${state.phase} -> ${next.phase})" |: next.phase == state.phase, + "doesn't create more effects" |: effects.isEmpty + ) + case err => + fail(s"unexpected $err") + } + } + } + } + } + + /** The leader handed the block created by the host system. */ + case class BlockCreatedCmd(event: Event.BlockCreated[TestAgreement]) + extends Command { + type Result = ProtocolState.Transition[TestAgreement] + + override def run(sut: Protocol): Result = { + sut.state.handleBlockCreated(event) match { + case result @ (next, _) => + sut.state = next + result + } + } + + override def nextState(state: State): State = + state.copy( + maybeBlockHash = Some(event.block.blockHash) + ) + + override def preCondition(state: State): Boolean = + event.viewNumber == state.viewNumber + + override def postCondition( + state: State, + result: Try[Result] + ): Prop = { + "BlockCreated" |: { + result match { + case Success((next, effects)) => + all( + "stay in Prepare" |: next.phase == Phase.Prepare, + "broadcast to all" |: effects.size == state.federation.size, + all( + effects.map { + case Effect.SendMessage(_, m: Message.Prepare[_]) => + all( + "send prepared block" |: m.block == event.block, + "send highQC" |: m.highQC == event.highQC + ) + case other => + fail(s"expected Prepare message: $other") + }: _* + ) + ) + case Failure(ex) => + fail(s"failed with $ex") + } + } + } + } + + /** Prepare from leader to a replica. 
*/ + case class PrepareCmd( + sender: TestAgreement.PKey, + message: Message.Prepare[TestAgreement] + ) extends MessageCmd { + override def nextState(state: State): State = { + state.copy( + phase = Phase.PreCommit, + maybeBlockHash = Some(message.block.blockHash) + ) + } + + override def preCondition(state: State): Boolean = { + message.viewNumber == state.viewNumber && + state.phase == Phase.Prepare && + (state.isLeader && state.maybeBlockHash.isDefined || + !state.isLeader && state.maybeBlockHash.isEmpty) + } + + override def postCondition( + state: Model, + result: Try[Result] + ): Prop = { + "Prepare" |: { + result match { + case Success(Right((next, effects))) => + all( + "move to PreCommit" |: next.phase == Phase.PreCommit, + "cast a vote" |: effects.size == 1, + effects.head match { + case Effect.SendMessage( + recipient, + Message.Vote(_, phase, blockHash, _) + ) => + all( + "vote Prepare" |: phase == Phase.Prepare, + "send to leader" |: recipient == state.leader, + "vote on block" |: blockHash == message.block.blockHash + ) + case other => + fail(s"unexpected effect $other") + } + ) + case other => + fail(s"unexpected result $other") + } + } + } + } + + /** A Vote from a replica to the leader. */ + case class VoteCmd( + sender: TestAgreement.PKey, + message: Message.Vote[TestAgreement] + ) extends MessageCmd { + override def nextState(state: State): State = + state.copy( + votesFrom = state.votesFrom + sender + ) + + override def preCondition(state: State): Boolean = + state.isLeader && + state.viewNumber == message.viewNumber && + votingPhaseFor(state.phase).contains(message.phase) && + state.maybeBlockHash.contains(message.blockHash) + + override def postCondition(state: Model, result: Try[Result]): Prop = { + "Vote" |: { + result match { + case Success(Right((next, effects))) => + val nextS = nextState(state) + val maybeBroadcast = + if ( + state.votesFrom.size < state.quorumSize && + nextS.votesFrom.size == state.quorumSize + ) { + "n - f collected" |: all( + "broadcast to all" |: effects.size == state.federation.size, + "all messages are quorums" |: all( + effects.map { + case Effect.SendMessage(_, Message.Quorum(_, qc)) => + all( + "quorum is about the current phase" |: qc.phase == message.phase, + "quorum is about the block" |: qc.blockHash == message.blockHash + ) + case other => + fail(s"unexpected effect $other") + }: _* + ) + ) + } else { + "not n - f" |: "not broadcast" |: effects.isEmpty + } + + all( + "stay in the same phase" |: next.phase == state.phase, + maybeBroadcast + ) + + case other => + fail(s"unexpected result $other") + } + } + } + } + + /** A Quorum from the leader to a replica. 
*/ + case class QuorumCmd( + sender: TestAgreement.PKey, + message: Message.Quorum[TestAgreement] + ) extends MessageCmd { + override def nextState(state: State): State = + state.copy( + viewNumber = + if (state.phase == Phase.Decide) state.viewNumber.next + else state.viewNumber, + phase = state.phase match { + case Phase.Prepare => Phase.PreCommit + case Phase.PreCommit => Phase.Commit + case Phase.Commit => Phase.Decide + case Phase.Decide => Phase.Prepare + }, + votesFrom = Set.empty, + newViewsFrom = Set.empty, + maybeBlockHash = + if (state.phase == Phase.Decide) None else state.maybeBlockHash, + prepareQCs = + if (message.quorumCertificate.phase == Phase.Prepare) + message.quorumCertificate :: state.prepareQCs + else state.prepareQCs, + newViewsHighQC = + if (state.phase == Phase.Decide) genesisQC else state.newViewsHighQC + ) + + override def preCondition(state: State): Boolean = + state.viewNumber == message.viewNumber && + votingPhaseFor(state.phase).contains(message.quorumCertificate.phase) && + state.maybeBlockHash.contains(message.quorumCertificate.blockHash) + + override def postCondition( + state: Model, + result: Try[Result] + ): Prop = { + "Quorum" |: { + result match { + case Success(Right((next, effects))) => + val nextS = nextState(state) + all( + "moves to the next state" |: next.phase == nextS.phase, + "votes for the next phase" |: (state.phase == Phase.Decide || + effects + .collectFirst { + case Effect.SendMessage(_, Message.Vote(_, phase, _, _)) => + phase == state.phase + } + .getOrElse(false)), + "makes a decision" |: (state.phase != Phase.Decide || + all( + "executes the block" |: effects.collectFirst { + case _: Effect.ExecuteBlocks[_] => + }.isDefined, + "remembers the executed block" |: + next.lastExecutedBlockHash == message.quorumCertificate.blockHash + )), + "saves the prepared block" |: (state.phase != Phase.PreCommit || + effects.collectFirst { case _: Effect.SaveBlock[_] => + }.isDefined) + ) + + case other => + fail(s"unexpected result $other") + } + } + } + } + + /** Check that a deliberately invalidated command returns a protocol error. */ + case class InvalidCmd(label: String, cmd: MessageCmd, isEarly: Boolean) + extends Command { + type Result = (Boolean, ProtocolState.TransitionAttempt[TestAgreement]) + + // The underlying command should return a `Left`, + // which means it shouldn't update the state. + override def run(sut: Protocol): Result = { + val event = Event.MessageReceived(cmd.sender, cmd.message) + val isStaticallyValid = sut.state.validateMessage(event).isRight + isStaticallyValid -> cmd.run(sut) + } + + // The model state validation is not as sophisticated, + // but because we know this is invalid, we know + // it should not cause any change in state. + override def nextState(state: State): State = + state + + // The invalidation should be strong enough that it doesn't + // become valid during shrinking. + override def preCondition(state: State): Boolean = + true + + override def postCondition( + state: State, + result: Try[Result] + ): Prop = + s"Invalid $label" |: { + result match { + case Success((isStaticallyValid, Left(error))) => + // Ensure that some errors are marked as TooEarly. 
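+            // A statically valid message that was only invalidated by bumping the
+            // view number (`isEarly`) should surface as `ProtocolError.TooEarly`
+            // rather than as a generic error.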
+ "is early" |: + isEarly && isStaticallyValid && error + .isInstanceOf[ProtocolError.TooEarly[_]] || + !isStaticallyValid || + !isEarly + + case other => + fail(s"unexpected result $other") + } + } + + } +} diff --git a/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/basic/Secp256k1SigningProps.scala b/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/basic/Secp256k1SigningProps.scala new file mode 100644 index 00000000..25a67e25 --- /dev/null +++ b/metronome/hotstuff/consensus/test/src/io/iohk/metronome/hotstuff/consensus/basic/Secp256k1SigningProps.scala @@ -0,0 +1,162 @@ +package io.iohk.metronome.hotstuff.consensus.basic + +import cats.implicits._ +import io.iohk.metronome.crypto +import io.iohk.metronome.crypto.hash.Hash +import io.iohk.metronome.crypto.{ECKeyPair, ECPublicKey} +import io.iohk.metronome.hotstuff.consensus.ArbitraryInstances._ +import io.iohk.metronome.hotstuff.consensus.{ + Federation, + LeaderSelection, + ViewNumber +} +import org.scalacheck.Arbitrary.arbitrary +import org.scalacheck.Prop.forAll +import org.scalacheck.{Gen, Properties, Test} +import scodec.bits.ByteVector + +import java.security.SecureRandom + +object Secp256k1SigningProps extends Properties("Secp256k1Signing") { + + override def overrideParameters(p: Test.Parameters): Test.Parameters = + p.withMinSuccessfulTests(10) + + object TestAgreement extends Secp256k1Agreement { + type Block = Nothing + type Hash = crypto.hash.Hash + } + type TestAgreement = TestAgreement.type + + def serializer( + phase: VotingPhase, + viewNumber: ViewNumber, + hash: crypto.hash.Hash + ): ByteVector = + ByteVector(phase.toString.getBytes) ++ + ByteVector.fromLong(viewNumber) ++ + hash + + def atLeast[A](n: Int, xs: Iterable[A]): Gen[Seq[A]] = { + require( + xs.size >= n, + s"There has to be at least $n elements to choose from" + ) + Gen.choose(n, xs.size).flatMap(Gen.pick(_, xs)).flatMap(_.toSeq) + } + + val signing = Signing.secp256k1[TestAgreement](serializer) + + val keyPairs = List.fill(20)(ECKeyPair.generate(new SecureRandom)) + + def buildFederation(kps: Iterable[ECKeyPair]): Federation[ECPublicKey] = + Federation(kps.map(_.pub).toIndexedSeq)( + LeaderSelection.RoundRobin + ).valueOr(e => throw new Exception(s"Could not build Federation: $e")) + + property("partialSignatureCreation") = forAll( + Gen.oneOf(keyPairs), + arbitrary[ViewNumber], + arbitrary[VotingPhase], + arbitrary[Hash] + ) { (keyPair, viewNumber, votingPhase, hash) => + val partialSig = signing.sign(keyPair.prv, votingPhase, viewNumber, hash) + signing.validate(keyPair.pub, partialSig, votingPhase, viewNumber, hash) + } + + property("noFalseValidation") = forAll( + Gen.pick(2, keyPairs), + arbitrary[ViewNumber], + arbitrary[VotingPhase], + arbitrary[Hash] + ) { case (kps, viewNumber, votingPhase, hash) => + val Seq(signingKp, validationKp) = kps.toSeq + + val partialSig = signing.sign(signingKp.prv, votingPhase, viewNumber, hash) + + !signing.validate( + validationKp.pub, + partialSig, + votingPhase, + viewNumber, + hash + ) + } + + property("groupSignatureCreation") = forAll( + for { + kps <- Gen.atLeastOne(keyPairs) + fed = buildFederation(kps) + signers <- Gen.pick(fed.quorumSize, kps) + } yield (fed, signers.map(_.prv)), + arbitrary[ViewNumber], + arbitrary[VotingPhase], + arbitrary[Hash] + ) { case ((federation, prvKeys), viewNumber, votingPhase, hash) => + val partialSigs = + prvKeys.map(k => signing.sign(k, votingPhase, viewNumber, hash)) + val groupSig = signing.combine(partialSigs.toList) + + 
signing.validate(federation, groupSig, votingPhase, viewNumber, hash) + } + + property("groupSignatureNonUniqueSigners") = forAll( + for { + kps <- atLeast(2, keyPairs) + fed = buildFederation(kps) + signers <- Gen.pick(fed.quorumSize - 1, kps) + repeated <- Gen.oneOf(signers) + } yield (kps, signers.map(_.prv), repeated.prv), + arbitrary[ViewNumber], + arbitrary[VotingPhase], + arbitrary[Hash] + ) { case ((kps, prvKeys, repeated), viewNumber, votingPhase, hash) => + val federation = buildFederation(kps) + + val partialSigs = + (repeated +: prvKeys).map(k => + signing.sign(k, votingPhase, viewNumber, hash) + ) + val groupSig = signing.combine(partialSigs.toList) + + !signing.validate(federation, groupSig, votingPhase, viewNumber, hash) + } + + property("groupSignatureForeignSigners") = forAll( + for { + kps <- Gen.atLeastOne(keyPairs) if kps.size < keyPairs.size + fed = buildFederation(kps) + signers <- Gen.pick(fed.quorumSize - 1, kps) + foreign <- Gen.oneOf(keyPairs.diff(kps)) + } yield (fed, signers.map(_.prv), foreign.prv), + arbitrary[ViewNumber], + arbitrary[VotingPhase], + arbitrary[Hash] + ) { case ((federation, prvKeys, foreign), viewNumber, votingPhase, hash) => + val partialSigs = + (foreign +: prvKeys).map(k => + signing.sign(k, votingPhase, viewNumber, hash) + ) + val groupSig = signing.combine(partialSigs.toList) + + !signing.validate(federation, groupSig, votingPhase, viewNumber, hash) + } + + property("groupSignatureNoQuorum") = forAll( + for { + kps <- Gen.atLeastOne(keyPairs) + fed = buildFederation(kps) + n <- Gen.choose(0, kps.size) if n != fed.quorumSize + signers <- Gen.pick(n, kps) + } yield (signers.map(_.prv), fed), + arbitrary[ViewNumber], + arbitrary[VotingPhase], + arbitrary[Hash] + ) { case ((prvKeys, federation), viewNumber, votingPhase, hash) => + val partialSigs = + prvKeys.map(k => signing.sign(k, votingPhase, viewNumber, hash)) + val groupSig = signing.combine(partialSigs.toList) + + !signing.validate(federation, groupSig, votingPhase, viewNumber, hash) + } +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/ApplicationService.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/ApplicationService.scala new file mode 100644 index 00000000..deeda9d3 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/ApplicationService.scala @@ -0,0 +1,37 @@ +package io.iohk.metronome.hotstuff.service + +import cats.data.{NonEmptyVector, NonEmptyList} +import io.iohk.metronome.hotstuff.consensus.basic.{Agreement, QuorumCertificate} + +/** Represents the "application" domain to the HotStuff module, + * performing all delegations that HotStuff can't do on its own. + */ +trait ApplicationService[F[_], A <: Agreement] { + // TODO (PM-3109): Create block. + def createBlock(highQC: QuorumCertificate[A]): F[Option[A#Block]] + + // TODO (PM-3132, PM-3133): Block validation. + // Returns None if validation cannot be carried out due to data availability issues within a given timeout. + def validateBlock(block: A#Block): F[Option[Boolean]] + + // TODO (PM-3108, PM-3107, PM-3137, PM-3110): Tell the application to execute a block. + // I cannot be sure that all blocks that get committed to gether fit into memory, + // so we pass them one by one, but all of them are accompanied by the final Commit Q.C. + // and the path of block hashes from the block being executed to the one committed. 
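+ // For illustration (with hypothetical hashes): committing the chain b1 -> b2 -> b3 under a + // Commit Q.C. on b3 results in three calls, with `commitPath` of [b1, b2, b3], [b2, b3] + // and [b3] respectively, each carrying the same Q.C.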
+ // Perhaps the application service can cache the headers if it needs to produce a + // proof of the BFT agreement at the end. + // Returns a flag to indicate whether the block execution results have been persisted, + // whether the block and any corresponding state can be used as a starting point after a restart. + def executeBlock( + block: A#Block, + commitQC: QuorumCertificate[A], + commitPath: NonEmptyList[A#Hash] + ): F[Boolean] + + // TODO (PM-3135): Tell the application to sync any state of the block, i.e. the Ledger. + // The `sources` are peers who most probably have this state. + // The full `block` is given because it may not be persisted yet. + // Return `true` if the block storage can be pruned after this operation from earlier blocks, + // which may not be the case if the application syncs by downloading all the blocks. + def syncState(sources: NonEmptyVector[A#PKey], block: A#Block): F[Boolean] +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/ConsensusService.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/ConsensusService.scala new file mode 100644 index 00000000..cd92ed66 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/ConsensusService.scala @@ -0,0 +1,616 @@ +package io.iohk.metronome.hotstuff.service + +import cats.implicits._ +import cats.effect.{Concurrent, Timer, Fiber, Resource, ContextShift} +import cats.effect.concurrent.Ref +import io.iohk.metronome.core.Validated +import io.iohk.metronome.core.fibers.FiberSet +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.hotstuff.consensus.basic.{ + Agreement, + Effect, + Event, + ProtocolState, + ProtocolError, + Phase, + Message, + Block, + Signing, + QuorumCertificate +} +import io.iohk.metronome.hotstuff.service.execution.BlockExecutor +import io.iohk.metronome.hotstuff.service.pipes.SyncPipe +import io.iohk.metronome.hotstuff.service.storage.{ + BlockStorage, + ViewStateStorage +} +import io.iohk.metronome.hotstuff.service.tracing.ConsensusTracers +import io.iohk.metronome.networking.{ConnectionHandler, Network} +import io.iohk.metronome.storage.KVStoreRunner +import monix.catnap.ConcurrentQueue +import scala.annotation.tailrec +import scala.collection.immutable.Queue +import scala.util.control.NonFatal +import io.iohk.metronome.hotstuff.service.execution.BlockExecutor + +/** An effectful executor wrapping the pure HotStuff ProtocolState. + * + * It handles the `consensus.basic.Message` events coming from the network. + */ +class ConsensusService[ + F[_]: Timer: Concurrent, + N, + A <: Agreement: Block: Signing +]( + publicKey: A#PKey, + network: Network[F, A#PKey, Message[A]], + appService: ApplicationService[F, A], + blockExecutor: BlockExecutor[F, N, A], + blockStorage: BlockStorage[N, A], + viewStateStorage: ViewStateStorage[N, A], + stateRef: Ref[F, ProtocolState[A]], + stashRef: Ref[F, ConsensusService.MessageStash[A]], + counterRef: Ref[F, ConsensusService.MessageCounter], + syncPipe: SyncPipe[F, A]#Left, + eventQueue: ConcurrentQueue[F, Event[A]], + fiberSet: FiberSet[F], + maxEarlyViewNumberDiff: Int +)(implicit tracers: ConsensusTracers[F, A], storeRunner: KVStoreRunner[F, N]) { + + import ConsensusService.MessageCounter + + /** Get the current protocol state, perhaps to respond to status requests. */ + def getState: F[ProtocolState[A]] = + stateRef.get + + /** Process incoming network messages. 
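+ * + * Each message is statically validated first; valid ones either get their missing + * block dependencies synchronised (for `Prepare`) or are enqueued directly for the + * protocol state machine.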
*/ + private def processNetworkMessages: F[Unit] = + network.incomingMessages + .mapEval[Unit] { case ConnectionHandler.MessageReceived(from, message) => + validateMessage(Event.MessageReceived(from, message)).flatMap { + case None => + ().pure[F] + case Some(valid) => + syncDependencies(valid) + } + } + .completedL + + /** First round of validation of message to decide if we should process it at all. */ + private def validateMessage( + event: Event.MessageReceived[A] + ): F[Option[Validated[Event.MessageReceived[A]]]] = + stateRef.get.flatMap { state => + state + .validateMessage(event) + .map(m => m: Event.MessageReceived[A]) match { + case Left(error) => + protocolError(error).as(none) + + case Right( + Event.MessageReceived( + sender, + message @ Message.Prepare(_, _, highQC) + ) + ) if state.commitQC.viewNumber > highQC.viewNumber => + // The sender is building on a block that is older than the committed one. + // This could be an attack, forcing us to re-download blocks we already pruned. + protocolError(ProtocolError.UnsafeExtension[A](sender, message)) + .as(none) + + case Right(valid) if valid.message.viewNumber < state.viewNumber => + tracers.fromPast(valid) >> + counterRef.update(_.incPast).as(none) + + case Right(valid) + if valid.message.viewNumber > state.viewNumber + maxEarlyViewNumberDiff => + tracers.fromFuture(valid) >> + counterRef.update(_.incFuture).as(none) + + case Right(valid) => + // We know that the message is to/from the leader and it's properly signed, + // although it may not match our current state, which we'll see later. + counterRef.update(_.incPresent).as(validated(valid).some) + } + } + + /** Synchronize any missing block dependencies, then enqueue the event for final processing. */ + private def syncDependencies( + message: Validated[Event.MessageReceived[A]] + ): F[Unit] = { + import Message._ + // Only syncing Prepare messages. They have the `highQC` as block parent, + // so we know that is something that is safe to sync; it's not a DoS attack. + // Other messages may be bogus: + // - a Vote can point at a non-existing block to force some download; + // we'd reject it anyway if it doesn't match the state we prepared + // - a Quorum could be a replay of some earlier one, maybe a block we have pruned + // - a NewView is similar, it's best to first wait and select the highest we know + message.message match { + case prepare @ Prepare(_, block, highQC) + if Block[A].parentBlockHash(block) != highQC.blockHash => + // The High Q.C. may be valid, but the block is not built on it. + protocolError(ProtocolError.UnsafeExtension(message.sender, prepare)) + + case prepare: Prepare[_] => + // Carry out syncing and validation asynchronously. + syncAndValidatePrepare(message.sender, prepare) + + case _: Vote[_] => + // Let the ProtocolState reject it if it's not about the prepared block. + enqueueEvent(message) + + case _: Quorum[_] => + // Let the ProtocolState reject it if it's not about the prepared block. + enqueueEvent(message) + + case _: NewView[_] => + // Let's assume that we will have the highest prepare Q.C. available, + // while some can be replays of old data we may not have any more. + // If it turns out we don't have the block after all, we'll figure it + // out in the `CreateBlock` effect, at which point we can time out + // and sync with the `Prepare` message from the next leader. + enqueueEvent(message) + } + } + + /** Trace an invalid message. Could also include other penalties for the sender. 
*/ + private def protocolError( + error: ProtocolError[A] + ): F[Unit] = + tracers.rejected(error) + + /** Add a Prepare message to the synchronisation and validation queue. + * + * The High Q.C. in the message proves that the parent block is valid + * according to the federation members. + * + * Any missing dependencies should be downloaded and the application asked + * to validate each block in succession as the downloads are finished. + */ + private def syncAndValidatePrepare( + sender: A#PKey, + prepare: Message.Prepare[A] + ): F[Unit] = + syncPipe.send(SyncPipe.PrepareRequest(sender, prepare)) + + /** Process the synchronization result queue. */ + private def processSyncPipe: F[Unit] = + syncPipe.receive + .mapEval[Unit] { + case SyncPipe.PrepareResponse(request, isValid) => + if (isValid) { + enqueueEvent( + validated(Event.MessageReceived(request.sender, request.prepare)) + ) + } else { + protocolError( + ProtocolError.UnsafeExtension(request.sender, request.prepare) + ) + } + + case SyncPipe.StatusResponse(status) => + fastForwardState(status) + } + .completedL + + /** Replace the current protocol state based on what was synced with the federation. */ + private def fastForwardState(status: Status[A]): F[Unit] = { + stateRef.get.flatMap { state => + val forward = state.copy[A]( + viewNumber = status.viewNumber, + prepareQC = status.prepareQC, + commitQC = status.commitQC + ) + // Trigger the next view, so we get proper tracing and effect execution. + tracers.adoptView(status) >> + handleTransition( + forward.handleNextView(Event.NextView(status.viewNumber)) + ) + } + } + + /** Add a validated event to the queue for processing against the protocol state. */ + private def enqueueEvent(event: Validated[Event[A]]): F[Unit] = + eventQueue.offer(event) + + /** Take a single event from the queue, apply it on the state, + * kick off the resulting effects, then recurse. + * + * The effects will communicate their results back to the state + * through the event queue. + */ + private def processEvents: F[Unit] = { + eventQueue.poll.flatMap { event => + stateRef.get.flatMap { state => + val handle: F[Unit] = event match { + case Event.NextView(viewNumber) if viewNumber < state.viewNumber => + ().pure[F] + + case e @ Event.NextView(viewNumber) => + for { + counter <- counterRef.get + _ <- tracers.timeout(viewNumber -> counter) + _ <- maybeRequestStatusSync(viewNumber, counter) + _ <- handleTransition(state.handleNextView(e)) + } yield () + + case e @ Event.MessageReceived(_, _) => + handleTransitionAttempt( + state.handleMessage(Validated[Event.MessageReceived[A]](e)) + ) + + case e @ Event.BlockCreated(_, _, _) => + handleTransition(state.handleBlockCreated(e)) + } + + handle >> processEvents + } + } + } + + /** Request view state synchronisation if we timed out and it looks like we're out of sync. */ + private def maybeRequestStatusSync( + viewNumber: ViewNumber, + counter: MessageCounter + ): F[Unit] = { + // Only requesting a state sync if we haven't received any message that looks to be in sync + // but we have received some from the future. If we have received messages from the past, + // then by the virtue of timeouts they should catch up with us at some point. + val isOutOfSync = counter.present == 0 && counter.future > 0 + + // In the case that there were two groups being in sync within group members, but not with + // each other, than there should be rounds when none of them are leaders and they shouldn't + // receive valid present messages. 
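+ // For example, MessageCounter(past = 2, present = 0, future = 3) triggers a status sync, + // while MessageCounter(past = 0, present = 1, future = 5) does not, because at least one + // message matched the current view.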
+ val requestSync = + tracers.viewSync(viewNumber) >> + syncPipe.send(SyncPipe.StatusRequest(viewNumber)) + + requestSync.whenA(isOutOfSync) + } + + /** Handle successful state transition: + * - apply local effects on the state + * - schedule other effects to execute in the background + * - if there was a phase or view transition, unstash delayed events + */ + private def handleTransition( + transition: ProtocolState.Transition[A] + ): F[Unit] = { + val (state, effects) = transition + + // Apply local messages to the state before anything else. + val (nextState, nextEffects) = + applySyncEffects(state, effects) + + // Unstash messages before we change state. + captureChanges(nextState) >> + unstash(nextState) >> + stateRef.set(nextState) >> + scheduleEffects(nextEffects) + } + + /** Update the view state and trace changes when they happen. */ + private def captureChanges(nextState: ProtocolState[A]): F[Unit] = { + stateRef.get.flatMap { state => + def ifChanged[T](get: ProtocolState[A] => T)(f: T => F[Unit]) = { + val prev = get(state) + val next = get(nextState) + f(next).whenA(prev != next) + } + + ifChanged(_.viewNumber)(_ => counterRef.set(MessageCounter.empty)) >> + ifChanged(_.viewNumber)(updateViewNumber) >> + ifChanged(_.prepareQC)(updateQuorum) >> + ifChanged(_.lockedQC)(updateQuorum) >> + ifChanged(_.commitQC)(updateQuorum) + } + } + + private def updateViewNumber(viewNumber: ViewNumber): F[Unit] = + tracers.newView(viewNumber) >> + storeRunner.runReadWrite { + viewStateStorage.setViewNumber(viewNumber) + } + + private def updateQuorum(quorumCertificate: QuorumCertificate[A]): F[Unit] = + tracers.quorum(quorumCertificate) >> + storeRunner.runReadWrite { + viewStateStorage.setQuorumCertificate(quorumCertificate) + } + + /** Requeue messages which arrived too early, but are now due because + * the state caught up with them. + */ + private def unstash(nextState: ProtocolState[A]): F[Unit] = + stateRef.get.flatMap { state => + val requeue = for { + dueEvents <- stashRef.modify { + _.unstash(nextState.viewNumber, nextState.phase) + } + _ <- dueEvents.traverse(e => enqueueEvent(validated(e))) + } yield () + + requeue.whenA( + nextState.viewNumber != state.viewNumber || nextState.phase != state.phase + ) + } + + /** Carry out local effects before anything else, + * to eliminate race conditions when a vote sent + * to self would have caused a state transition. + * + * Return the updated state and the effects to be + * carried out asynchronously. + */ + private def applySyncEffects( + state: ProtocolState[A], + effects: Seq[Effect[A]] + ): ProtocolState.Transition[A] = { + @tailrec + def loop( + state: ProtocolState[A], + effectQueue: Queue[Effect[A]], + asyncEffects: List[Effect[A]] + ): ProtocolState.Transition[A] = + effectQueue.dequeueOption match { + case None => + (state, asyncEffects.reverse) + + case (Some((effect, effectQueue))) => + effect match { + case Effect.SendMessage(recipient, message) + if recipient == publicKey => + val event = Event.MessageReceived(recipient, message) + + state.handleMessage(validated(event)) match { + case Left(_) => + // This shouldn't happen, but let's just skip this event here and redeliver it later. 
+ loop(state, effectQueue, effect :: asyncEffects) + + case Right((state, effects)) => + loop(state, effectQueue ++ effects, asyncEffects) + } + + case _ => + loop(state, effectQueue, effect :: asyncEffects) + } + } + + loop(state, Queue(effects: _*), Nil) + } + + /** Try to apply a transition: + * - if it's `TooEarly`, add it to the delayed stash + * - if it's another error, ignore the event + * - otherwise carry out the transition + */ + private def handleTransitionAttempt( + transitionAttempt: ProtocolState.TransitionAttempt[A] + ): F[Unit] = transitionAttempt match { + case Left(error @ ProtocolError.TooEarly(_, _, _)) => + tracers.stashed(error) >> + stashRef.update { _.stash(error) } + + case Left(error) => + protocolError(error) + + case Right(transition) => + handleTransition(transition) + } + + /** Effects can be processed independently of each other in the background. */ + private def scheduleEffects(effects: Seq[Effect[A]]): F[Unit] = + effects.toList.traverse(scheduleEffect).void + + /** Start processing an effect in the background. Add the background fiber + * to the scheduled items so they can be canceled if the service is released. + */ + private def scheduleEffect(effect: Effect[A]): F[Unit] = { + fiberSet.submit(processEffect(effect)).void + } + + /** Process a single effect. This will always be wrapped in a Fiber. */ + private def processEffect(effect: Effect[A]): F[Unit] = { + import Event._ + import Effect._ + + val process = effect match { + case ScheduleNextView(viewNumber, timeout) => + val event = validated(NextView(viewNumber)) + Timer[F].sleep(timeout) >> enqueueEvent(event) + + case CreateBlock(viewNumber, highQC) => + // Ask the application to create a block for us. + appService.createBlock(highQC).flatMap { + case None => + ().pure[F] + + case Some(block) => + enqueueEvent( + validated(Event.BlockCreated(viewNumber, block, highQC)) + ) + } + + case SaveBlock(preparedBlock) => + storeRunner.runReadWrite { + blockStorage.put(preparedBlock) + } + + case effect @ ExecuteBlocks(_, _) => + // Each node may be at a different point in the chain, so how + // long the executions take can vary. We could execute it in + // the forground here, but it may cause the node to lose its + // sync with the other federation members, so the execution + // should be offloaded to another queue. + blockExecutor.enqueue(effect) + + case SendMessage(recipient, message) => + network.sendMessage(recipient, message) + } + + process.handleErrorWith { case NonFatal(ex) => + tracers.error(s"Error processing effect $effect", ex) + } + } + + private def validated(event: Event[A]): Validated[Event[A]] = + Validated[Event[A]](event) + + private def validated( + event: Event.MessageReceived[A] + ): Validated[Event.MessageReceived[A]] = + Validated[Event.MessageReceived[A]](event) +} + +object ConsensusService { + + /** Stash to keep too early messages to be re-queued later. + * + * Every slot just has 1 place per federation member to avoid DoS attacks. 
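+ * + * A minimal usage sketch, where `MyAgreement`, `tooEarlyError`, `viewNumber` and `phase` + * stand in for concrete values: + * {{{ + * val stashed = MessageStash.empty[MyAgreement].stash(tooEarlyError) + * // once the state reaches the expected view and phase: + * val (remaining, dueEvents) = stashed.unstash(viewNumber, phase) + * }}}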
+ */ + case class MessageStash[A <: Agreement]( + slots: Map[(ViewNumber, Phase), Map[A#PKey, Message[A]]] + ) { + def stash(error: ProtocolError.TooEarly[A]): MessageStash[A] = { + val slotKey = (error.expectedInViewNumber, error.expectedInPhase) + val slot = slots.getOrElse(slotKey, Map.empty) + copy(slots = + slots.updated( + slotKey, + slot.updated(error.event.sender, error.event.message) + ) + ) + } + + def unstash( + dueViewNumber: ViewNumber, + duePhase: Phase + ): (MessageStash[A], List[Event.MessageReceived[A]]) = { + val dueKeys = slots.keySet.filter { case (viewNumber, phase) => + viewNumber < dueViewNumber || + viewNumber == dueViewNumber && + !phase.isAfter(duePhase) + } + + val dueEvents = dueKeys.toList.map(slots).flatten.map { + case (sender, message) => Event.MessageReceived(sender, message) + } + + copy(slots = slots -- dueKeys) -> dueEvents + } + } + object MessageStash { + def empty[A <: Agreement] = MessageStash[A](Map.empty) + } + + /** Count the number of messages received from others in a round, + * to determine whether we're out of sync or not in case of a timeout. + */ + case class MessageCounter( + past: Int, + present: Int, + future: Int + ) { + def incPast = copy(past = past + 1) + def incPresent = copy(present = present + 1) + def incFuture = copy(future = future + 1) + } + object MessageCounter { + val empty = MessageCounter(0, 0, 0) + } + + /** Create a `ConsensusService` instance and start processing events + * in the background, shutting processing down when the resource is + * released. + * + * `initState` is expected to be restored from persistent storage + * instances upon restart. + */ + def apply[ + F[_]: Timer: Concurrent: ContextShift, + N, + A <: Agreement: Block: Signing + ]( + publicKey: A#PKey, + network: Network[F, A#PKey, Message[A]], + appService: ApplicationService[F, A], + blockExecutor: BlockExecutor[F, N, A], + blockStorage: BlockStorage[N, A], + viewStateStorage: ViewStateStorage[N, A], + syncPipe: SyncPipe[F, A]#Left, + initState: ProtocolState[A], + maxEarlyViewNumberDiff: Int = 1 + )(implicit + tracers: ConsensusTracers[F, A], + storeRunner: KVStoreRunner[F, N] + ): Resource[F, ConsensusService[F, N, A]] = + for { + fiberSet <- FiberSet[F] + + service <- Resource.liftF( + build[F, N, A]( + publicKey, + network, + appService, + blockExecutor, + blockStorage, + viewStateStorage, + syncPipe, + initState, + maxEarlyViewNumberDiff, + fiberSet + ) + ) + + _ <- Concurrent[F].background(service.processNetworkMessages) + _ <- Concurrent[F].background(service.processSyncPipe) + _ <- Concurrent[F].background(service.processEvents) + + initEffects = ProtocolState.init(initState) + _ <- Resource.liftF(service.scheduleEffects(initEffects)) + } yield service + + private def build[ + F[_]: Timer: Concurrent: ContextShift, + N, + A <: Agreement: Block: Signing + ]( + publicKey: A#PKey, + network: Network[F, A#PKey, Message[A]], + appService: ApplicationService[F, A], + blockExecutor: BlockExecutor[F, N, A], + blockStorage: BlockStorage[N, A], + viewStateStorage: ViewStateStorage[N, A], + syncPipe: SyncPipe[F, A]#Left, + initState: ProtocolState[A], + maxEarlyViewNumberDiff: Int, + fiberSet: FiberSet[F] + )(implicit + tracers: ConsensusTracers[F, A], + storeRunner: KVStoreRunner[F, N] + ): F[ConsensusService[F, N, A]] = + for { + stateRef <- Ref[F].of(initState) + stashRef <- Ref[F].of(MessageStash.empty[A]) + fibersRef <- Ref[F].of(Set.empty[Fiber[F, Unit]]) + counterRef <- Ref[F].of(MessageCounter.empty) + eventQueue <- 
ConcurrentQueue[F].unbounded[Event[A]](None) + + service = new ConsensusService( + publicKey, + network, + appService, + blockExecutor, + blockStorage, + viewStateStorage, + stateRef, + stashRef, + counterRef, + syncPipe, + eventQueue, + fiberSet, + maxEarlyViewNumberDiff + ) + } yield service +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/HotStuffService.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/HotStuffService.scala new file mode 100644 index 00000000..1aa875e2 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/HotStuffService.scala @@ -0,0 +1,93 @@ +package io.iohk.metronome.hotstuff.service + +import cats.Parallel +import cats.effect.{Concurrent, ContextShift, Resource, Timer} +import io.iohk.metronome.hotstuff.consensus.basic.{ + Agreement, + ProtocolState, + Message, + Block, + Signing +} +import io.iohk.metronome.hotstuff.service.execution.BlockExecutor +import io.iohk.metronome.hotstuff.service.messages.{ + HotStuffMessage, + SyncMessage +} +import io.iohk.metronome.hotstuff.service.pipes.SyncPipe +import io.iohk.metronome.hotstuff.service.storage.{ + BlockStorage, + ViewStateStorage +} +import io.iohk.metronome.hotstuff.service.tracing.{ + ConsensusTracers, + SyncTracers +} +import io.iohk.metronome.networking.Network +import io.iohk.metronome.storage.KVStoreRunner + +object HotStuffService { + + /** Start up the HotStuff service stack. */ + def apply[ + F[_]: Concurrent: ContextShift: Timer: Parallel, + N, + A <: Agreement: Block: Signing + ]( + network: Network[F, A#PKey, HotStuffMessage[A]], + appService: ApplicationService[F, A], + blockStorage: BlockStorage[N, A], + viewStateStorage: ViewStateStorage[N, A], + initState: ProtocolState[A] + )(implicit + consensusTracers: ConsensusTracers[F, A], + syncTracers: SyncTracers[F, A], + storeRunner: KVStoreRunner[F, N] + ): Resource[F, Unit] = + for { + (consensusNetwork, syncNetwork) <- Network + .splitter[F, A#PKey, HotStuffMessage[A], Message[A], SyncMessage[A]]( + network + )( + split = { + case HotStuffMessage.ConsensusMessage(message) => Left(message) + case HotStuffMessage.SyncMessage(message) => Right(message) + }, + merge = { + case Left(message) => HotStuffMessage.ConsensusMessage(message) + case Right(message) => HotStuffMessage.SyncMessage(message) + } + ) + + syncPipe <- Resource.liftF { SyncPipe[F, A] } + + blockExecutor <- BlockExecutor[F, N, A]( + appService, + blockStorage, + viewStateStorage + ) + + consensusService <- ConsensusService( + initState.publicKey, + consensusNetwork, + appService, + blockExecutor, + blockStorage, + viewStateStorage, + syncPipe.left, + initState + ) + + syncService <- SyncService( + initState.publicKey, + initState.federation, + syncNetwork, + appService, + blockExecutor, + blockStorage, + viewStateStorage, + syncPipe.right, + consensusService.getState + ) + } yield () +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/Status.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/Status.scala new file mode 100644 index 00000000..7336ac34 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/Status.scala @@ -0,0 +1,15 @@ +package io.iohk.metronome.hotstuff.service + +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.hotstuff.consensus.basic.{Agreement, QuorumCertificate} + +/** Status has all the fields necessary for nodes to sync with each other. 
+ * + * This is to facilitate nodes rejoining the network, + * or re-syncing their views after some network glitch. + */ +case class Status[A <: Agreement]( + viewNumber: ViewNumber, + prepareQC: QuorumCertificate[A], + commitQC: QuorumCertificate[A] +) diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/SyncService.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/SyncService.scala new file mode 100644 index 00000000..fd675782 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/SyncService.scala @@ -0,0 +1,421 @@ +package io.iohk.metronome.hotstuff.service + +import cats.implicits._ +import cats.Parallel +import cats.effect.{Sync, Resource, Concurrent, ContextShift, Timer} +import io.iohk.metronome.core.fibers.FiberMap +import io.iohk.metronome.core.messages.{ + RPCMessageCompanion, + RPCPair, + RPCTracker +} +import io.iohk.metronome.hotstuff.consensus.{Federation, ViewNumber} +import io.iohk.metronome.hotstuff.consensus.basic.{ + Agreement, + ProtocolState, + Block, + Signing +} +import io.iohk.metronome.hotstuff.service.execution.BlockExecutor +import io.iohk.metronome.hotstuff.service.messages.SyncMessage +import io.iohk.metronome.hotstuff.service.pipes.SyncPipe +import io.iohk.metronome.hotstuff.service.storage.{ + BlockStorage, + ViewStateStorage +} +import io.iohk.metronome.hotstuff.service.sync.{ + BlockSynchronizer, + ViewSynchronizer +} +import io.iohk.metronome.hotstuff.service.tracing.SyncTracers +import io.iohk.metronome.networking.{ConnectionHandler, Network} +import io.iohk.metronome.storage.KVStoreRunner +import scala.util.control.NonFatal +import scala.concurrent.duration._ +import scala.reflect.ClassTag + +/** The `SyncService` handles the `SyncMessage`s coming from the network, + * i.e. serving block and status requests, as well as receive responses + * for outgoing requests for missing dependencies. + * + * It will match up the `requestId`s in the responses and discard any + * unsolicited message. + * + * The block and view synchronisation components will use this service + * to send requests to the network. + */ +class SyncService[F[_]: Concurrent: ContextShift, N, A <: Agreement: Block]( + publicKey: A#PKey, + network: Network[F, A#PKey, SyncMessage[A]], + appService: ApplicationService[F, A], + blockExecutor: BlockExecutor[F, N, A], + blockStorage: BlockStorage[N, A], + viewStateStorage: ViewStateStorage[N, A], + syncPipe: SyncPipe[F, A]#Right, + getState: F[ProtocolState[A]], + incomingFiberMap: FiberMap[F, A#PKey], + rpcTracker: RPCTracker[F, SyncMessage[A]] +)(implicit tracers: SyncTracers[F, A], storeRunner: KVStoreRunner[F, N]) { + import SyncMessage._ + + type BlockSync = SyncService.BlockSynchronizerWithFiberMap[F, N, A] + + private def protocolStatus: F[Status[A]] = + getState.map { state => + Status(state.viewNumber, state.prepareQC, state.commitQC) + } + + /** Request a block from a peer. */ + private def getBlock(from: A#PKey, blockHash: A#Hash): F[Option[A#Block]] = { + for { + requestId <- RequestId[F] + request = GetBlockRequest(requestId, blockHash) + maybeResponse <- sendRequest(from, request) + } yield maybeResponse.map(_.block) + } + + /** Request the status of a peer. 
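+ * + * Requests addressed to ourselves are answered from the local protocol + * state instead of going over the network.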
*/ + private def getStatus(from: A#PKey): F[Option[Status[A]]] = + if (from == publicKey) { + protocolStatus.map(_.some) + } else { + for { + requestId <- RequestId[F] + request = GetStatusRequest[A](requestId) + maybeResponse <- sendRequest(from, request) + } yield maybeResponse.map(_.status) + } + + /** Send a request to the peer and track the response. + * + * Returns `None` if we're not connected or the request times out. + */ + private def sendRequest[ + Req <: RPCMessageCompanion#Request, + Res <: RPCMessageCompanion#Response + ](from: A#PKey, request: Req)(implicit + ev1: Req <:< SyncMessage[A] with SyncMessage.Request, + ev2: RPCPair.Aux[Req, Res], + ct: ClassTag[Res] + ): F[Option[Res]] = { + for { + join <- rpcTracker.register[Req, Res](request) + _ <- network.sendMessage(from, request) + res <- join + _ <- tracers.requestTimeout(from -> request).whenA(res.isEmpty) + } yield res + } + + /** Process incoming network messages. */ + private def processNetworkMessages: F[Unit] = { + // TODO (PM-3186): Rate limiting per source. + network.incomingMessages + .mapEval[Unit] { case ConnectionHandler.MessageReceived(from, message) => + // Handle on a fiber dedicated to the source. + incomingFiberMap + .submit(from) { + processNetworkMessage(from, message) + } + .attemptNarrow[FiberMap.QueueFullException] + .flatMap { + case Right(_) => ().pure[F] + case Left(_) => tracers.queueFull(from) + } + } + .completedL + } + + /** Process one incoming network message. + * + * It's going to be executed on a fiber. + */ + private def processNetworkMessage( + from: A#PKey, + message: SyncMessage[A] + ): F[Unit] = { + val process = message match { + case GetStatusRequest(requestId) => + protocolStatus.flatMap { status => + network.sendMessage( + from, + GetStatusResponse(requestId, status) + ) + } + + case GetBlockRequest(requestId, blockHash) => + storeRunner + .runReadOnly { + blockStorage.get(blockHash) + } + .flatMap { + case None => + ().pure[F] + case Some(block) => + network.sendMessage( + from, + GetBlockResponse(requestId, block) + ) + } + + case response: SyncMessage.Response => + rpcTracker.complete(response).flatMap { + case Right(ok) => + tracers.responseIgnored((from, response, None)).whenA(!ok) + case Left(ex) => + tracers.responseIgnored((from, response, Some(ex))) + } + } + + process.handleErrorWith { case NonFatal(ex) => + tracers.error(ex) + } + } + + /** Read Requests from the SyncPipe and send Responses. + * + * These are coming from the `ConsensusService` asking for a + * `Prepare` message to be synchronized with the sender, or + * for the view to be synchronized with the whole federation. + */ + private def processSyncPipe( + makeBlockSync: F[BlockSync], + viewSynchronizer: ViewSynchronizer[F, A] + ): F[Unit] = + syncPipe.receive.consume.use { consumer => + def loop( + blockSync: BlockSync, + lastSyncedViewNumber: ViewNumber + ): F[Unit] = { + consumer.pull.flatMap { + case Right(SyncPipe.PrepareRequest(_, prepare)) + if prepare.viewNumber < lastSyncedViewNumber => + // We have already synced to a Commit Q.C. higher than this old PrepareRequest. + loop(blockSync, lastSyncedViewNumber) + + case Right(SyncPipe.StatusRequest(viewNumber)) + if viewNumber < lastSyncedViewNumber => + // We have already synced higher than this old StatusRequest. 
+ loop(blockSync, lastSyncedViewNumber) + + case Right(request @ SyncPipe.PrepareRequest(_, _)) => + handlePrepareRequest(blockSync, request) >> + loop(blockSync, lastSyncedViewNumber) + + case Right(request @ SyncPipe.StatusRequest(_)) => + handleStatusRequest( + makeBlockSync, + blockSync, + viewSynchronizer, + request + ).flatMap { + (loop _).tupled + } + + case Left(maybeError) => + blockSync.fiberMapRelease >> + maybeError.fold(().pure[F])(Sync[F].raiseError(_)) + } + } + + makeBlockSync.flatMap { blockSync => + loop(blockSync, ViewNumber(0)) + } + } + + /** Sync with the sender up to the High Q.C. it sent, then validate the prepared block. + * + * This is done in the background, while further requests are taken from the pipe. + */ + private def handlePrepareRequest( + blockSync: BlockSync, + request: SyncPipe.PrepareRequest[A] + ): F[Unit] = { + val sender = request.sender + val prepare = request.prepare + // It is enough to respond to the last block positively; it will indicate + // that the whole range can be executed later (at that point from storage). + // If the same leader is sending us newer proposals, we can ignore the + // previous prepared blocks - they are either part of the new Q.C., + // in which case they don't need to be validated, or they have not + // gathered enough votes, and been superseded by a new proposal. + blockSync.fiberMap.cancelQueue(sender) >> + blockSync.fiberMap + .submit(sender) { + blockSync.synchronizer.sync(sender, prepare.highQC) >> + validateBlock(prepare.block) >>= { + case Some(isValid) => + syncPipe.send(SyncPipe.PrepareResponse(request, isValid)) + case None => + // We didn't have data to decide validity in time; not responding. + ().pure[F] + } + } + .void + } + + /** Validate the prepared block after the parent has been downloaded. */ + private def validateBlock(block: A#Block): F[Option[Boolean]] = { + // Short circuiting validation. + def runChecks(checks: F[Option[Boolean]]*) = + checks.reduce[F[Option[Boolean]]] { case (a, b) => + a.flatMap { + case Some(true) => b + case other => other.pure[F] + } + } + + runChecks( + storeRunner.runReadOnly { + blockStorage + .get(Block[A].parentBlockHash(block)) + .map(_.map(Block[A].isParentOf(_, block))) + }, + Block[A].isValid(block).some.pure[F], + appService.validateBlock(block) + ) + } + + /** Shut down any outstanding block downloads, sync the view, + * then create another block synchronizer instance to resume with. + */ + private def handleStatusRequest( + makeBlockSync: F[BlockSync], + blockSync: BlockSync, + viewSynchronizer: ViewSynchronizer[F, A], + request: SyncPipe.StatusRequest + ): F[(BlockSync, ViewNumber)] = + for { + // Cancel all outstanding block syncing. + _ <- blockSync.fiberMapRelease + // The block synchronizer is still usable. + viewNumber <- syncStatus( + blockSync.synchronizer, + viewSynchronizer + ).handleErrorWith { case NonFatal(ex) => + tracers.error(ex).as(request.viewNumber) + } + // Create a fresh fiber and block synchronizer instance. + // When the previous goes out of scope, its ephemeral storage is freed. + newBlockSync <- makeBlockSync + } yield (newBlockSync, viewNumber) + + /** Get the latest status of federation members, download the corresponding block + * and prune all existing block history, making the latest Commit Q.C. the new + * root in the block tree. + * + * This is done in the foreground; no further requests are taken from the pipe. 
+ */ + private def syncStatus( + blockSynchronizer: BlockSynchronizer[F, N, A], + viewSynchronizer: ViewSynchronizer[F, A] + ): F[ViewNumber] = + for { + // Sync to the latest Commit Q.C. + federationStatus <- viewSynchronizer.sync + status = federationStatus.status + + // Download the block in the Commit Q.C. + block <- blockSynchronizer + .getBlockFromQuorumCertificate( + federationStatus.sources, + status.commitQC + ) + .rethrow + + // Sync any application specific state, e.g. a ledger, + // then potentially prune old blocks from the storage. + _ <- blockExecutor.syncState(federationStatus.sources, block) + + // Tell the ConsensusService about the new Status. + _ <- syncPipe.send(SyncPipe.StatusResponse(status)) + } yield status.viewNumber +} + +object SyncService { + + /** Create a `SyncService` instance and start processing messages + * in the background, shutting processing down when the resource is + * released. + */ + def apply[ + F[_]: Concurrent: ContextShift: Timer: Parallel, + N, + A <: Agreement: Block: Signing + ]( + publicKey: A#PKey, + federation: Federation[A#PKey], + network: Network[F, A#PKey, SyncMessage[A]], + appService: ApplicationService[F, A], + blockExecutor: BlockExecutor[F, N, A], + blockStorage: BlockStorage[N, A], + viewStateStorage: ViewStateStorage[N, A], + syncPipe: SyncPipe[F, A]#Right, + getState: F[ProtocolState[A]], + timeout: FiniteDuration = 10.seconds + )(implicit + tracers: SyncTracers[F, A], + storeRunner: KVStoreRunner[F, N] + ): Resource[F, SyncService[F, N, A]] = + // TODO (PM-3186): Add capacity as part of rate limiting. + for { + incomingFiberMap <- FiberMap[F, A#PKey]() + rpcTracker <- Resource.liftF { + RPCTracker[F, SyncMessage[A]](timeout) + } + service = new SyncService( + publicKey, + network, + appService, + blockExecutor, + blockStorage, + viewStateStorage, + syncPipe, + getState, + incomingFiberMap, + rpcTracker + ) + + blockSync = for { + (syncFiberMap, syncFiberMapRelease) <- FiberMap[F, A#PKey]().allocated + blockSynchronizer <- BlockSynchronizer[F, N, A]( + publicKey, + federation, + blockStorage, + service.getBlock + ) + } yield BlockSynchronizerWithFiberMap( + blockSynchronizer, + syncFiberMap, + syncFiberMapRelease + ) + + viewSynchronizer = new ViewSynchronizer[F, A]( + federation, + service.getStatus + ) + + _ <- Concurrent[F].background { + service.processNetworkMessages + } + _ <- Concurrent[F].background { + service.processSyncPipe(blockSync, viewSynchronizer) + } + } yield service + + /** The `SyncService` can be in two modes: either we're in sync with the federation + * and downloading the odd missing block every now and then, or we are out of sync, + * in which case we need to ask everyone to find out what the current view number + * is, and then jump straight to the latest Commit Quorum Certificate. + * + * Our implementation assumes that this is always supported by the application. + * + * When we go from block sync to view sync, the block syncs happening in the + * background on the fiber ap in this class are canceled, and the synchronizer + * instance with its ephemeral storage is discarded. 
+ */ + case class BlockSynchronizerWithFiberMap[F[_], N, A <: Agreement]( + synchronizer: BlockSynchronizer[F, N, A], + fiberMap: FiberMap[F, A#PKey], + fiberMapRelease: F[Unit] + ) +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/execution/BlockExecutor.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/execution/BlockExecutor.scala new file mode 100644 index 00000000..63cbe831 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/execution/BlockExecutor.scala @@ -0,0 +1,303 @@ +package io.iohk.metronome.hotstuff.service.execution + +import cats.implicits._ +import cats.data.{NonEmptyList, NonEmptyVector} +import cats.effect.{Sync, Concurrent, ContextShift, Resource} +import cats.effect.concurrent.Semaphore +import io.iohk.metronome.hotstuff.service.ApplicationService +import io.iohk.metronome.hotstuff.service.storage.{ + BlockStorage, + ViewStateStorage +} +import io.iohk.metronome.hotstuff.consensus.basic.{ + Agreement, + Block, + Effect, + QuorumCertificate +} +import io.iohk.metronome.hotstuff.service.tracing.ConsensusTracers +import io.iohk.metronome.storage.KVStoreRunner +import monix.catnap.ConcurrentQueue + +/** The `BlockExecutor` receives ranges of committed blocks from the + * `ConsensusService` and carries out their effects, marking the last + * executed block in the `ViewStateStorage`, so that we can resume + * from where we left off last time after a restart. + * + * It delegates other state updates to the `ApplicationService`. + * + * The `BlockExecutor` is prepared for gaps to appear in the ranges, + * which happens if the node is out of sync with the federation and + * needs to jump ahead. + */ +class BlockExecutor[F[_]: Sync, N, A <: Agreement: Block]( + appService: ApplicationService[F, A], + blockStorage: BlockStorage[N, A], + viewStateStorage: ViewStateStorage[N, A], + executionQueue: ConcurrentQueue[F, Effect.ExecuteBlocks[A]], + executionSemaphore: Semaphore[F] +)(implicit tracers: ConsensusTracers[F, A], storeRunner: KVStoreRunner[F, N]) { + + /** Add a newly committed batch of blocks to the execution queue. */ + def enqueue(effect: Effect.ExecuteBlocks[A]): F[Unit] = + executionQueue.offer(effect) + + /** Fast forward state to a given block. + * + * This operation is delegated to the `BlockExecutor` so that it can make sure + * that it's not executing other blocks at the same time. + */ + def syncState( + sources: NonEmptyVector[A#PKey], + block: A#Block + ): F[Unit] = + executionSemaphore.withPermit { + for { + // Sync any application specific state, e.g. a ledger. + // Do this before we prune the existing blocks and set the new root. + canPrune <- appService.syncState(sources, block) + // Prune the block store from earlier blocks that are no longer traversable. + _ <- fastForwardStorage(block, canPrune) + } yield () + } + + /** Execute blocks in order, updating persistent storage along the way. */ + private def executeBlocks: F[Unit] = { + executionQueue.poll + .flatMap { case Effect.ExecuteBlocks(lastCommittedBlockHash, commitQC) => + // Retrieve the blocks from the storage from the last executed + // to the one in the Quorum Certificate and tell the application + // to execute them one by one. Update the persistent view state + // after each execution to remember which blocks we have actually + // executed. + // Protect the whole thing with a semaphore from `syncState` being + // carried out at the same time. 
+ executionSemaphore.withPermit { + for { + lastExecutedBlockHash <- getLastExecutedBlockHash + blockHashes <- getBlockPath( + lastExecutedBlockHash, + lastCommittedBlockHash, + commitQC + ) + _ <- blockHashes match { + case _ :: newBlockHashes => + tryExecuteBatch(newBlockHashes, commitQC, lastExecutedBlockHash) + case Nil => + ().pure[F] + } + } yield () + } + } >> executeBlocks + } + + /** Read whatever was the last executed block that we peristed, + * either by doing individual execution or state sync. + */ + private def getLastExecutedBlockHash: F[A#Hash] = + storeRunner.runReadOnly { + viewStateStorage.getLastExecutedBlockHash + } + + /** Update the last executed block hash, unless something else updated it + * while we were executing blocks. This shouldn't happen if we used the + * executor to carry out the state sync within the semaphore. + */ + private def setLastExecutedBlockHash( + blockHash: A#Hash, + lastExecutedBlockHash: A#Hash + ): F[Boolean] = + storeRunner.runReadWrite { + viewStateStorage + .compareAndSetLastExecutedBlockHash( + blockHash, + lastExecutedBlockHash + ) + } + + /** Get the more complete path. We may not have the last executed block any more. + * + * The first hash in the return value is a block that has already been executed. + */ + private def getBlockPath( + lastExecutedBlockHash: A#Hash, + lastCommittedBlockHash: A#Hash, + commitQC: QuorumCertificate[A] + ): F[List[A#Hash]] = { + def readPath(ancestorBlockHash: A#Hash) = + storeRunner + .runReadOnly { + blockStorage.getPathFromAncestor( + ancestorBlockHash, + commitQC.blockHash + ) + } + + readPath(lastExecutedBlockHash) + .flatMap { + case Nil => + readPath(lastCommittedBlockHash) + case path => + path.pure[F] + } + } + + /** Try to execute a batch of newly committed blocks. + * + * The last executed block hash is used to track that it hasn't + * been modified by the jump-ahead state sync mechanism while + * we were executing blocks. + * + * In general we cannot expect to be able to cancel an ongoing execution, + * it may be in the middle of carrying out some real-world effects that + * don't support cancellation. We use the semaphore to protect against + * race conditions between executing blocks here and the fast-forward + * synchroniser making changes to state. + */ + private def tryExecuteBatch( + newBlockHashes: List[A#Hash], + commitQC: QuorumCertificate[A], + lastExecutedBlockHash: A#Hash + ): F[Unit] = { + def loop( + newBlockHashes: List[A#Hash], + lastExecutedBlockHash: A#Hash + ): F[Unit] = + newBlockHashes match { + case Nil => + ().pure[F] + + case blockHash :: nextBlockHashes => + executeBlock( + blockHash, + commitQC, + NonEmptyList(blockHash, nextBlockHashes), + lastExecutedBlockHash + ).attempt.flatMap { + case Left(ex) => + // If a block fails, return what we managed to do so far, + // so we can re-attempt it next time if the block is still + // available in the storage. + tracers + .error(s"Error executing block $blockHash", ex) + + case Right(None) => + // Either the block couldn't be found, or the last executed + // hash changed, but something suggests that we should not + // try to execute more of this batch. + nextBlockHashes.traverse(tracers.executionSkipped(_)).void + + case Right(Some(nextLastExecutedBlockHash)) => + loop(nextBlockHashes, nextLastExecutedBlockHash) + } + } + + loop(newBlockHashes, lastExecutedBlockHash) + } + + /** Execute the next block in line and update the view state. 
+ * + * The last executed block hash is only updated if the application + * indicates that it has persisted the results, and if no other + * changes have been made to it outside this loop. The execution + * result carries the new last executed block hash to use in the + * next iteration, or `None` if we should abandon the execution. + */ + private def executeBlock( + blockHash: A#Hash, + commitQC: QuorumCertificate[A], + commitPath: NonEmptyList[A#Hash], + lastExecutedBlockHash: A#Hash + ): F[Option[A#Hash]] = { + assert(commitPath.head == blockHash) + assert(commitPath.last == commitQC.blockHash) + + storeRunner.runReadOnly { + blockStorage.get(blockHash) + } flatMap { + case None => + tracers.executionSkipped(blockHash).as(none) + + case Some(block) => + for { + isPersisted <- appService.executeBlock(block, commitQC, commitPath) + _ <- tracers.blockExecuted(blockHash) + + maybeLastExecutedBlockHash <- + if (!isPersisted) { + // Keep the last for the next compare and set below. + lastExecutedBlockHash.some.pure[F] + } else { + // Check that nothing else changed the view state, + // which should be true as long as we use the semaphore. + // Otherwise it would be up to the `ApplicationService` to + // take care of isolation, and check that the block being + // executed is the one we expected. + setLastExecutedBlockHash(blockHash, lastExecutedBlockHash).map { + case true => blockHash.some + case false => none + } + } + } yield maybeLastExecutedBlockHash + } + } + + /** Replace the state we have persisted with what we synced with the federation. + * + * Prunes old blocks, the Commit Q.C. will be the new root. + */ + private def fastForwardStorage( + block: A#Block, + canPrune: Boolean + ): F[Unit] = { + val blockHash = Block[A].blockHash(block) + + val prune = for { + viewState <- viewStateStorage.getBundle.lift + // Prune old data, but keep the new block. + // Traversing from the old root, because the + // new block is probably not connected to it. + _ <- blockStorage.purgeTree( + viewState.rootBlockHash, + keep = blockHash.some + ) + _ <- viewStateStorage.setRootBlockHash(blockHash) + } yield () + + val query = for { + // Insert the new block. + _ <- blockStorage.put(block) + _ <- prune.whenA(canPrune) + // Considering the committed block as executed, we have its state already. 
+ _ <- viewStateStorage.setLastExecutedBlockHash(blockHash) + } yield () + + storeRunner.runReadWrite(query) + } +} + +object BlockExecutor { + def apply[F[_]: Concurrent: ContextShift, N, A <: Agreement: Block]( + appService: ApplicationService[F, A], + blockStorage: BlockStorage[N, A], + viewStateStorage: ViewStateStorage[N, A] + )(implicit + tracers: ConsensusTracers[F, A], + storeRunner: KVStoreRunner[F, N] + ): Resource[F, BlockExecutor[F, N, A]] = for { + executionQueue <- Resource.liftF { + ConcurrentQueue[F].unbounded[Effect.ExecuteBlocks[A]](None) + } + executionSemaphore <- Resource.liftF(Semaphore[F](1)) + executor = new BlockExecutor[F, N, A]( + appService, + blockStorage, + viewStateStorage, + executionQueue, + executionSemaphore + ) + _ <- Concurrent[F].background { + executor.executeBlocks + } + } yield executor +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/messages/DuplexMessage.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/messages/DuplexMessage.scala new file mode 100644 index 00000000..92015dc4 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/messages/DuplexMessage.scala @@ -0,0 +1,24 @@ +package io.iohk.metronome.hotstuff.service.messages + +import io.iohk.metronome.hotstuff +import io.iohk.metronome.hotstuff.consensus.basic.Agreement + +/** Messages type to use in the networking layer if the use case has + * application specific message types, e.g. ledger synchronisation, + * not just the general BFT agreement (which could be enough if + * we need to execute all blocks to synchronize state). + */ +sealed trait DuplexMessage[A <: Agreement, M] + +object DuplexMessage { + + /** General BFT agreement message. */ + case class AgreementMessage[A <: Agreement]( + message: hotstuff.service.messages.HotStuffMessage[A] + ) extends DuplexMessage[A, Nothing] + + /** Application specific message. */ + case class ApplicationMessage[M]( + message: M + ) extends DuplexMessage[Nothing, M] +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/messages/HotStuffMessage.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/messages/HotStuffMessage.scala new file mode 100644 index 00000000..90809bf9 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/messages/HotStuffMessage.scala @@ -0,0 +1,23 @@ +package io.iohk.metronome.hotstuff.service.messages + +import io.iohk.metronome.hotstuff +import io.iohk.metronome.hotstuff.consensus.basic.Agreement + +/** Messages which are generic to any HotStuff BFT agreement. */ +sealed trait HotStuffMessage[A <: Agreement] + +object HotStuffMessage { + + /** Messages which are part of the basic HotStuff BFT algorithm itself. */ + case class ConsensusMessage[A <: Agreement]( + message: hotstuff.consensus.basic.Message[A] + ) extends HotStuffMessage[A] + + /** Messages that support the HotStuff BFT agreement but aren't part of + * the core algorithm, e.g. block and view number synchronisation. 
+ */ + case class SyncMessage[A <: Agreement]( + message: hotstuff.service.messages.SyncMessage[A] + ) extends HotStuffMessage[A] + +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/messages/SyncMessage.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/messages/SyncMessage.scala new file mode 100644 index 00000000..bfea9c39 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/messages/SyncMessage.scala @@ -0,0 +1,41 @@ +package io.iohk.metronome.hotstuff.service.messages + +import io.iohk.metronome.core.messages.{RPCMessage, RPCMessageCompanion} +import io.iohk.metronome.hotstuff.consensus.basic.Agreement +import io.iohk.metronome.hotstuff.service.Status + +/** Messages needed to fully realise the HotStuff protocol, + * without catering for any application specific concerns. + */ +sealed trait SyncMessage[+A <: Agreement] { self: RPCMessage => } + +object SyncMessage extends RPCMessageCompanion { + case class GetStatusRequest[A <: Agreement]( + requestId: RequestId + ) extends SyncMessage[A] + with Request + + case class GetStatusResponse[A <: Agreement]( + requestId: RequestId, + status: Status[A] + ) extends SyncMessage[A] + with Response + + case class GetBlockRequest[A <: Agreement]( + requestId: RequestId, + blockHash: A#Hash + ) extends SyncMessage[A] + with Request + + case class GetBlockResponse[A <: Agreement]( + requestId: RequestId, + block: A#Block + ) extends SyncMessage[A] + with Response + + implicit def getBlockPair[A <: Agreement] = + pair[GetBlockRequest[A], GetBlockResponse[A]] + + implicit def getStatusPair[A <: Agreement] = + pair[GetStatusRequest[A], GetStatusResponse[A]] +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/pipes/SyncPipe.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/pipes/SyncPipe.scala new file mode 100644 index 00000000..61a9f03a --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/pipes/SyncPipe.scala @@ -0,0 +1,58 @@ +package io.iohk.metronome.hotstuff.service.pipes + +import cats.effect.{Concurrent, ContextShift} +import io.iohk.metronome.core.Pipe +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.hotstuff.consensus.basic.{Agreement, Message} +import io.iohk.metronome.hotstuff.service.Status + +object SyncPipe { + + sealed trait Request[+A <: Agreement] + sealed trait Response[+A <: Agreement] + + /** Request the synchronization component to download + * any missing dependencies up to the High Q.C., + * perform any application specific validation, + * including the block in the `Prepare` message, + * and persist the blocks up to, but not including + * the block in the `Prepare` message. + * + * This is because the block being prepared is + * subject to further validation and voting, + * while the one in the High Q.C. has gathered + * a quorum from the federation. + */ + case class PrepareRequest[A <: Agreement]( + sender: A#PKey, + prepare: Message.Prepare[A] + ) extends Request[A] + + /** Respond with the outcome of whether the + * block we're being asked to prepare is + * valid, according to the application rules. + */ + case class PrepareResponse[A <: Agreement]( + request: PrepareRequest[A], + isValid: Boolean + ) extends Response[A] + + /** Request that the view state is synchronized with the whole federation, + * including downloading the block and state corresponding to the latest + * Commit Q.C. 
+ * + * The eventual response should contain the new view status to be applied + * on the protocol state. + */ + case class StatusRequest(viewNumber: ViewNumber) extends Request[Nothing] + + /** Response with the new status to resume the protocol from, after the + * state has been synchronized up to the included Commit Q.C. + */ + case class StatusResponse[A <: Agreement]( + status: Status[A] + ) extends Response[A] + + def apply[F[_]: Concurrent: ContextShift, A <: Agreement]: F[SyncPipe[F, A]] = + Pipe[F, SyncPipe.Request[A], SyncPipe.Response[A]] +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/pipes/package.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/pipes/package.scala new file mode 100644 index 00000000..ea4defe7 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/pipes/package.scala @@ -0,0 +1,11 @@ +package io.iohk.metronome.hotstuff.service + +import io.iohk.metronome.core.Pipe +import io.iohk.metronome.hotstuff.consensus.basic.Agreement + +package object pipes { + + /** Communication pipe with the block synchronization and validation component. */ + type SyncPipe[F[_], A <: Agreement] = + Pipe[F, SyncPipe.Request[A], SyncPipe.Response[A]] +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/storage/BlockStorage.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/storage/BlockStorage.scala new file mode 100644 index 00000000..4d658e6f --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/storage/BlockStorage.scala @@ -0,0 +1,33 @@ +package io.iohk.metronome.hotstuff.service.storage + +import io.iohk.metronome.storage.{KVCollection, KVTree} +import io.iohk.metronome.hotstuff.consensus.basic.{Agreement, Block} + +/** Storage for blocks that maintains parent-child relationships as well, + * to facilitate tree traversal and pruning. + * + * It is assumed that the application maintains some pointers into the tree + * where it can start traversing from, e.g. the last Commit Quorum Certificate + * would point at a block hash which would serve as the entry point. 
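+ * + * A minimal construction sketch (the namespace literals are illustrative and + * the required scodec codecs are assumed to be in scope): + * {{{ + * val blockStorage = new BlockStorage[String, A]( + * new KVCollection[String, A#Hash, A#Block]("blocks"), + * new KVCollection[String, A#Hash, KVTree.NodeMeta[A#Hash]]("block-metas"), + * new KVCollection[String, A#Hash, Set[A#Hash]]("block-to-children") + * ) + * }}}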
+ */ +class BlockStorage[N, A <: Agreement: Block]( + blockColl: KVCollection[N, A#Hash, A#Block], + blockMetaColl: KVCollection[N, A#Hash, KVTree.NodeMeta[A#Hash]], + parentToChildrenColl: KVCollection[N, A#Hash, Set[A#Hash]] +) extends KVTree[N, A#Hash, A#Block]( + blockColl, + blockMetaColl, + parentToChildrenColl + )(BlockStorage.node[A]) + +object BlockStorage { + implicit def node[A <: Agreement: Block]: KVTree.Node[A#Hash, A#Block] = + new KVTree.Node[A#Hash, A#Block] { + override def key(value: A#Block): A#Hash = + Block[A].blockHash(value) + override def parentKey(value: A#Block): A#Hash = + Block[A].parentBlockHash(value) + override def height(value: A#Block): Long = + Block[A].height(value) + } +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/storage/ViewStateStorage.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/storage/ViewStateStorage.scala new file mode 100644 index 00000000..4cb166eb --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/storage/ViewStateStorage.scala @@ -0,0 +1,172 @@ +package io.iohk.metronome.hotstuff.service.storage + +import cats.implicits._ +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.hotstuff.consensus.basic.{ + Agreement, + QuorumCertificate, + Phase +} +import io.iohk.metronome.storage.{KVStore, KVStoreRead} +import scodec.{Codec, Encoder, Decoder} + +class ViewStateStorage[N, A <: Agreement] private ( + namespace: N +)(implicit + keys: ViewStateStorage.Keys[A], + kvn: KVStore.Ops[N], + kvrn: KVStoreRead.Ops[N], + codecVN: Codec[ViewNumber], + codecQC: Codec[QuorumCertificate[A]], + codecH: Codec[A#Hash] +) { + import keys.Key + + private def put[V: Encoder](key: Key[V], value: V) = + KVStore[N].put[Key[V], V](namespace, key, value) + + private def read[V: Decoder](key: Key[V]): KVStoreRead[N, V] = + KVStoreRead[N].read[Key[V], V](namespace, key).map(_.get) + + def setViewNumber(viewNumber: ViewNumber): KVStore[N, Unit] = + put(Key.ViewNumber, viewNumber) + + def setQuorumCertificate(qc: QuorumCertificate[A]): KVStore[N, Unit] = + qc.phase match { + case Phase.Prepare => + put(Key.PrepareQC, qc) + case Phase.PreCommit => + put(Key.LockedQC, qc) + case Phase.Commit => + put(Key.CommitQC, qc) + } + + def setLastExecutedBlockHash(blockHash: A#Hash): KVStore[N, Unit] = + put(Key.LastExecutedBlockHash, blockHash) + + /** Set `LastExecutedBlockHash` to `blockHash` if it's still what it was before. */ + def compareAndSetLastExecutedBlockHash( + blockHash: A#Hash, + lastExecutedBlockHash: A#Hash + ): KVStore[N, Boolean] = + read(Key.LastExecutedBlockHash).lift.flatMap { current => + if (current == lastExecutedBlockHash) { + setLastExecutedBlockHash(blockHash).as(true) + } else { + KVStore[N].pure(false) + } + } + + def setRootBlockHash(blockHash: A#Hash): KVStore[N, Unit] = + put(Key.RootBlockHash, blockHash) + + val getBundle: KVStoreRead[N, ViewStateStorage.Bundle[A]] = + ( + read(Key.ViewNumber), + read(Key.PrepareQC), + read(Key.LockedQC), + read(Key.CommitQC), + read(Key.LastExecutedBlockHash), + read(Key.RootBlockHash) + ).mapN(ViewStateStorage.Bundle.apply[A] _) + + val getLastExecutedBlockHash: KVStoreRead[N, A#Hash] = + read(Key.LastExecutedBlockHash) +} + +object ViewStateStorage { + + /** Storing elements of the view state individually under separate keys, + * because they get written independently. 
+ */ + trait Keys[A <: Agreement] { + sealed abstract class Key[V](private val code: Int) + object Key { + case object ViewNumber extends Key[ViewNumber](0) + case object PrepareQC extends Key[QuorumCertificate[A]](1) + case object LockedQC extends Key[QuorumCertificate[A]](2) + case object CommitQC extends Key[QuorumCertificate[A]](3) + case object LastExecutedBlockHash extends Key[A#Hash](4) + case object RootBlockHash extends Key[A#Hash](5) + + implicit def encoder[V]: Encoder[Key[V]] = + scodec.codecs.uint8.contramap[Key[V]](_.code) + } + } + + /** The state of consensus that needs to be persisted between restarts. + * + * The fields are a subset of the `ProtocolState` but have a slightly + * different life cycle, e.g. `lastExecutedBlockHash` is only updated + * when the blocks are actually executed, which happens asynchronously. + */ + case class Bundle[A <: Agreement]( + viewNumber: ViewNumber, + prepareQC: QuorumCertificate[A], + lockedQC: QuorumCertificate[A], + commitQC: QuorumCertificate[A], + lastExecutedBlockHash: A#Hash, + rootBlockHash: A#Hash + ) { + assert(prepareQC.phase == Phase.Prepare) + assert(lockedQC.phase == Phase.PreCommit) + assert(commitQC.phase == Phase.Commit) + } + object Bundle { + + /** Convenience method reflecting the expectation that the signature + * in the genesis Q.C. will not depend on the phase, just the genesis + * hash. + */ + def fromGenesisQC[A <: Agreement](genesisQC: QuorumCertificate[A]) = + Bundle[A]( + viewNumber = genesisQC.viewNumber, + prepareQC = genesisQC.copy[A](phase = Phase.Prepare), + lockedQC = genesisQC.copy[A](phase = Phase.PreCommit), + commitQC = genesisQC.copy[A](phase = Phase.Commit), + lastExecutedBlockHash = genesisQC.blockHash, + rootBlockHash = genesisQC.blockHash + ) + } + + /** Create a ViewStateStorage instance by pre-loading it with the genesis, + * unless it already has data.
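+ * + * A bootstrap sketch, mirroring how a node would start from a genesis + * certificate (the namespace literal is illustrative; codecs are assumed + * to be in scope): + * {{{ + * val genesisBundle = ViewStateStorage.Bundle.fromGenesisQC(genesisQC) + * val init: KVStore[String, ViewStateStorage[String, A]] = + * ViewStateStorage[String, A]("view-state", genesisBundle) + * }}} + * The returned program only writes keys that are not present yet, so running + * it against an existing database leaves the persisted view state intact.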
+ */ + def apply[N, A <: Agreement]( + namespace: N, + genesis: Bundle[A] + )(implicit + codecVN: Codec[ViewNumber], + codecQC: Codec[QuorumCertificate[A]], + codecH: Codec[A#Hash] + ): KVStore[N, ViewStateStorage[N, A]] = { + implicit val kvn = KVStore.instance[N] + implicit val kvrn = KVStoreRead.instance[N] + implicit val keys = new Keys[A] {} + import keys.Key + + def setDefault[V](default: V): Option[V] => Option[V] = + (current: Option[V]) => current orElse Some(default) + + for { + _ <- KVStore[N].alter(namespace, Key.ViewNumber)( + setDefault(genesis.viewNumber) + ) + _ <- KVStore[N].alter(namespace, Key.PrepareQC)( + setDefault(genesis.prepareQC) + ) + _ <- KVStore[N].alter(namespace, Key.LockedQC)( + setDefault(genesis.lockedQC) + ) + _ <- KVStore[N].alter(namespace, Key.CommitQC)( + setDefault(genesis.commitQC) + ) + _ <- KVStore[N].alter(namespace, Key.LastExecutedBlockHash)( + setDefault(genesis.lastExecutedBlockHash) + ) + _ <- KVStore[N].alter(namespace, Key.RootBlockHash)( + setDefault(genesis.rootBlockHash) + ) + } yield new ViewStateStorage[N, A](namespace) + } +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/sync/BlockSynchronizer.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/sync/BlockSynchronizer.scala new file mode 100644 index 00000000..3f5f1ee7 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/sync/BlockSynchronizer.scala @@ -0,0 +1,313 @@ +package io.iohk.metronome.hotstuff.service.sync + +import cats.implicits._ +import cats.data.NonEmptyVector +import cats.effect.{Sync, Timer, Concurrent, ContextShift} +import cats.effect.concurrent.Semaphore +import io.iohk.metronome.hotstuff.consensus.Federation +import io.iohk.metronome.hotstuff.consensus.basic.{ + Agreement, + QuorumCertificate, + Block +} +import io.iohk.metronome.hotstuff.service.storage.BlockStorage +import io.iohk.metronome.storage.{InMemoryKVStore, KVStoreRunner} +import scala.concurrent.duration._ +import scala.util.Random +import scala.util.control.NoStackTrace + +/** The job of the `BlockSynchronizer` is to procure missing blocks when a `Prepare` + * message builds on a High Q.C. that we don't have. + * + * It will walk backwards, asking for the ancestors until we find one that we already + * have in persistent storage, then append blocks to the storage in the opposite order. + * + * Since the final block has a Quorum Certificate, there's no need to validate the + * ancestors, assuming an honest majority in the federation. The only validation we + * need to do is hash checks to make sure we're getting the correct blocks. + * + * The synchronizer keeps the tentative blocks in memory until they can be connected + * to the persistent storage. We assume that we never have to download the block history + * back until genesis, but rather that the application will always have support for + * syncing to any given block and its associated state, to catch up after spending + * a long time offline. Once that happens the block history should be pruneable. 
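+ * + * A construction and usage sketch, assuming `getBlock` is the network request + * function and `highQC` is the High Q.C. taken from a `Prepare` message: + * {{{ + * for { + * synchronizer <- BlockSynchronizer[F, N, A](publicKey, federation, blockStorage, getBlock) + * _ <- synchronizer.sync(sender, highQC) + * } yield () + * }}}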
+ */ +class BlockSynchronizer[F[_]: Sync: Timer, N, A <: Agreement: Block]( + publicKey: A#PKey, + federation: Federation[A#PKey], + blockStorage: BlockStorage[N, A], + getBlock: BlockSynchronizer.GetBlock[F, A], + inMemoryStore: KVStoreRunner[F, N], + semaphore: Semaphore[F], + retryTimeout: FiniteDuration = 5.seconds +)(implicit storeRunner: KVStoreRunner[F, N]) { + import BlockSynchronizer.DownloadFailedException + + private val otherPublicKeys = + federation.publicKeys.filterNot(_ == publicKey) + + // We must take care not to insert blocks into storage and risk losing + // the pointer to them in a restart, hence keeping the unfinished tree + // in memory until we find a parent we do have in storage, then + // insert them in the opposite order. + + /** Download all blocks up to the one included in the Quorum Certificate. + * + * Expected to be called at most once per sender at a time, otherwise + * it may request the same ancestor block multiple times concurrently. + * + * This could be managed with internal queueing, but not having that should + * make it easier to cancel all calling fibers and discard the synchronizer + * instance and its in-memory store, do state syncing, then replace it with + * a fresh one. + */ + def sync( + sender: A#PKey, + quorumCertificate: QuorumCertificate[A] + ): F[Unit] = + for { + path <- download(sender, quorumCertificate.blockHash, Nil) + _ <- persist(quorumCertificate.blockHash, path) + } yield () + + /** Download the block in the Quorum Certificate without ancestors. + * + * Return it without being persisted. + * + * Unlike `sync`, which is expected to be canceled if consensus times out, + * or be satisfied by alternative downloads happening concurrently, this + * method returns an error if it cannot download the block after a certain + * number of attempts, from any of the sources. This is because its primary + * use is during state syncing where this is the only operation, and if for + * any reason the block would be gone from every honest member's storage, + * we have to try something else. + */ + def getBlockFromQuorumCertificate( + sources: NonEmptyVector[A#PKey], + quorumCertificate: QuorumCertificate[A] + ): F[Either[DownloadFailedException[A], A#Block]] = { + val otherSources = sources.filterNot(_ == publicKey).toList + + def loop( + sources: List[A#PKey] + ): F[Either[DownloadFailedException[A], A#Block]] = { + sources match { + case Nil => + new DownloadFailedException( + quorumCertificate.blockHash, + sources.toVector + ).asLeft[A#Block].pure[F] + + case source :: alternatives => + getAndValidateBlock(source, quorumCertificate.blockHash, otherSources) + .flatMap { + case None => + loop(alternatives) + case Some(block) => + block.asRight[DownloadFailedException[A]].pure[F] + } + } + } + + storeRunner + .runReadOnly { + blockStorage.get(quorumCertificate.blockHash) + } + .flatMap { + case None => loop(Random.shuffle(otherSources)) + case Some(block) => block.asRight[DownloadFailedException[A]].pure[F] + } + } + + /** Download a block and all of its ancestors into the in-memory block store. + * + * Returns the path from the greatest ancestor that had to be downloaded + * to the originally requested block, so that we can persist them in that order. + * + * The path is maintained separately from the in-memory store in case another + * ongoing download would re-insert something on a path already partially removed, + * resulting in a forest that cannot be traversed fully.
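+ * + * For example, if persistent storage already contains block `A` and we are + * asked for the hash of `D` in the chain `A <- B <- C <- D`, then `D`, `C` + * and `B` are fetched into the in-memory store (in that order) and the + * returned path is `List(B, C, D)`, i.e. the order in which the blocks can + * be moved to persistent storage.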
+ */ + private def download( + from: A#PKey, + blockHash: A#Hash, + path: List[A#Hash] + ): F[List[A#Hash]] = { + storeRunner + .runReadOnly { + blockStorage.contains(blockHash) + } + .flatMap { + case true => + path.pure[F] + + case false => + inMemoryStore + .runReadOnly { + blockStorage.get(blockHash) + } + .flatMap { + case Some(block) => + downloadParent(from, block, path) + + case None => + getAndValidateBlock(from, blockHash) + .flatMap { + case Some(block) => + inMemoryStore.runReadWrite { + blockStorage.put(block) + } >> downloadParent(from, block, path) + + case None => + Timer[F].sleep(retryTimeout) >> + download(from, blockHash, path) + } + } + } + } + + private def downloadParent( + from: A#PKey, + block: A#Block, + path: List[A#Hash] + ): F[List[A#Hash]] = { + val blockHash = Block[A].blockHash(block) + val parentBlockHash = Block[A].parentBlockHash(block) + download(from, parentBlockHash, blockHash :: path) + } + + /** Try downloading the block from the source and perform basic content validation. + * + * If the download fails, try random alternative sources in the federation. + */ + private def getAndValidateBlock( + from: A#PKey, + blockHash: A#Hash, + alternativeSources: Seq[A#PKey] = otherPublicKeys + ): F[Option[A#Block]] = { + def fetch(from: A#PKey) = + getBlock(from, blockHash) + .map { maybeBlock => + maybeBlock.filter { block => + Block[A].blockHash(block) == blockHash && + Block[A].isValid(block) + } + } + + def loop(sources: List[A#PKey]): F[Option[A#Block]] = + sources match { + case Nil => none.pure[F] + case from :: sources => + fetch(from).flatMap { + case None => loop(sources) + case block => block.pure[F] + } + } + + loop(List(from)).flatMap { + case None => + loop(Random.shuffle(alternativeSources.filterNot(_ == from).toList)) + case block => + block.pure[F] + } + } + + /** See how far we can go in memory from the original block hash we asked for, + * which indicates the blocks that no concurrent download has persisted yet, + * then persist the rest. + * + * Only doing one persist operation at a time to make sure there's no competition + * in the insertion order of the path elements among concurrent downloads. + */ + private def persist( + targetBlockHash: A#Hash, + path: List[A#Hash] + ): F[Unit] = + semaphore.withPermit { + inMemoryStore + .runReadOnly { + blockStorage.getPathFromRoot(targetBlockHash) + } + .flatMap { unpersisted => + persistAndClear(path, unpersisted.toSet) + } + } + + /** Move the blocks on the path from memory to persistent storage. + * + * `path` and `unpersisted` can be different when a concurrent download + * re-inserts some ancestor block into the in-memory store that another + * download has already removed during persistence. The `unpersisted` + * set only contains block that need to be inserted into persistent + * storage, but all `path` elements have to be visited to make sure + * nothing is left in the in-memory store, leaking memory. + */ + private def persistAndClear( + path: List[A#Hash], + unpersisted: Set[A#Hash] + ): F[Unit] = + path match { + case Nil => + ().pure[F] + + case blockHash :: rest => + inMemoryStore + .runReadWrite { + for { + maybeBlock <- blockStorage.get(blockHash).lift + // There could be other, overlapping paths being downloaded, + // but as long as they are on the call stack, it's okay to + // create a forest here. 
+ _ <- blockStorage.deleteUnsafe(blockHash) + } yield maybeBlock + } + .flatMap { + case Some(block) if unpersisted(blockHash) => + storeRunner + .runReadWrite { + blockStorage.put(block) + } + case _ => + // Another download has already persisted it. + ().pure[F] + + } >> + persistAndClear(rest, unpersisted) + } +} + +object BlockSynchronizer { + + class DownloadFailedException[A <: Agreement]( + blockHash: A#Hash, + sources: Seq[A#PKey] + ) extends RuntimeException( + s"Failed to download block ${blockHash} from ${sources.size} sources." + ) + with NoStackTrace + + /** Send a network request to get a block. */ + type GetBlock[F[_], A <: Agreement] = (A#PKey, A#Hash) => F[Option[A#Block]] + + /** Create a block synchronizer. */ + def apply[F[_]: Concurrent: ContextShift: Timer, N, A <: Agreement: Block]( + publicKey: A#PKey, + federation: Federation[A#PKey], + blockStorage: BlockStorage[N, A], + getBlock: GetBlock[F, A] + )(implicit + storeRunner: KVStoreRunner[F, N] + ): F[BlockSynchronizer[F, N, A]] = + for { + semaphore <- Semaphore[F](1) + inMemoryStore <- InMemoryKVStore[F, N] + synchronizer = new BlockSynchronizer[F, N, A]( + publicKey, + federation, + blockStorage, + getBlock, + inMemoryStore, + semaphore + ) + } yield synchronizer +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/sync/ViewSynchronizer.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/sync/ViewSynchronizer.scala new file mode 100644 index 00000000..0e6b8452 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/sync/ViewSynchronizer.scala @@ -0,0 +1,191 @@ +package io.iohk.metronome.hotstuff.service.sync + +import cats._ +import cats.implicits._ +import cats.effect.{Timer, Sync} +import cats.data.{NonEmptySeq, NonEmptyVector} +import io.iohk.metronome.core.Validated +import io.iohk.metronome.hotstuff.consensus.{Federation, ViewNumber} +import io.iohk.metronome.hotstuff.consensus.basic.{ + Agreement, + Signing, + QuorumCertificate, + Phase +} +import io.iohk.metronome.hotstuff.service.Status +import io.iohk.metronome.hotstuff.service.tracing.SyncTracers +import scala.concurrent.duration._ +import io.iohk.metronome.hotstuff.consensus.basic.ProtocolError + +/** The job of the `ViewSynchronizer` is to ask the other federation members + * what their status is and figure out a view number we should be using. + * This is something we must do after startup, or if we have for some reason + * fallen out of sync with the rest of the federation. + */ +class ViewSynchronizer[F[_]: Sync: Timer: Parallel, A <: Agreement: Signing]( + federation: Federation[A#PKey], + getStatus: ViewSynchronizer.GetStatus[F, A], + retryTimeout: FiniteDuration = 5.seconds +)(implicit tracers: SyncTracers[F, A]) { + import ViewSynchronizer.{aggregateStatus, FederationStatus} + + /** Poll the federation members for the current status until we have gathered + * enough to make a decision, i.e. we have a quorum. + * + * Pick the highest Quorum Certificates from the gathered responses, but be + * more careful with the view number as these can be disingenuous. + * + * Try again until we can gather at least a quorum of statuses in a single round.
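+ * + * A usage sketch, assuming a `SyncTracers` instance is in scope, `getStatus` + * is the network request function, and `adopt` stands for whatever the + * caller does with the result: + * {{{ + * val synchronizer = new ViewSynchronizer[F, A](federation, getStatus) + * synchronizer.sync.flatMap { case FederationStatus(status, sources) => + * // Adopt `status` and keep `sources` for subsequent block downloads. + * adopt(status, sources) + * } + * }}}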
+ */ + def sync: F[ViewSynchronizer.FederationStatus[A]] = { + federation.publicKeys.toVector + .parTraverse(getAndValidateStatus) + .flatMap { maybeStatuses => + val statusMap = (federation.publicKeys zip maybeStatuses).collect { + case (k, Some(s)) => k -> s + }.toMap + + tracers + .statusPoll(statusMap) + .as(statusMap) + } + .flatMap { + case statusMap if statusMap.size >= federation.quorumSize => + val statuses = statusMap.values.toList + val status = aggregateStatus(NonEmptySeq.fromSeqUnsafe(statuses)) + + // Returning everyone who responded so we always have a quorum sized set to talk to. + val sources = + NonEmptyVector.fromVectorUnsafe(statusMap.keySet.toVector) + + FederationStatus(status, sources).pure[F] + + case _ => + // We traced all responses, so we can detect if we're in an endless loop. + Timer[F].sleep(retryTimeout) >> sync + } + } + + private def getAndValidateStatus( + from: A#PKey + ): F[Option[Validated[Status[A]]]] = + getStatus(from).flatMap { + case None => + none.pure[F] + + case Some(status) => + validate(from, status) match { + case Left((error, hint)) => + tracers.invalidStatus(status, error, hint).as(none) + case Right(valid) => + valid.some.pure[F] + } + } + + private def validate( + from: A#PKey, + status: Status[A] + ): Either[ + (ProtocolError.InvalidQuorumCertificate[A], ViewSynchronizer.Hint), + Validated[Status[A]] + ] = + for { + _ <- validateQC(from, status.prepareQC)( + checkPhase(Phase.Prepare), + checkSignature, + checkVisible(status), + checkPrepareIsAfterCommit(status) + ) + _ <- validateQC(from, status.commitQC)( + checkPhase(Phase.Commit), + checkSignature, + checkVisible(status) + ) + } yield Validated[Status[A]](status) + + private def check(cond: Boolean, hint: => String) = + if (cond) none else hint.some + + private def checkPhase(phase: Phase)(qc: QuorumCertificate[A]) = + check(phase == qc.phase, s"Phase should be $phase.") + + private def checkSignature(qc: QuorumCertificate[A]) = + check(Signing[A].validate(federation, qc), "Invalid signature.") + + private def checkVisible(status: Status[A])(qc: QuorumCertificate[A]) = + check( + status.viewNumber >= qc.viewNumber, + "View number of status earlier than Q.C." + ) + + // This could be checked from either Q.C. perspective. + private def checkPrepareIsAfterCommit(status: Status[A]) = + (_: QuorumCertificate[A]) => + check( + status.prepareQC.viewNumber >= status.commitQC.viewNumber, + "Prepare Q.C. lower than Commit Q.C." + ) + + private def validateQC(from: A#PKey, qc: QuorumCertificate[A])( + checks: (QuorumCertificate[A] => Option[String])* + ) = + checks.toList.traverse { check => + check(qc) + .map { hint => + ProtocolError.InvalidQuorumCertificate(from, qc) -> hint + } + .toLeft(()) + } +} + +object ViewSynchronizer { + + /** Extra textual description for errors. */ + type Hint = String + + /** Send a network request to get the status of a replica. */ + type GetStatus[F[_], A <: Agreement] = A#PKey => F[Option[Status[A]]] + + /** Determines the best values to adopt: it picks the highest Prepare and + * Commit Quorum Certificates, and the median View Number. + * + * The former have signatures to prove their validity, but the latter could be + * gamed by adversarial actors, hence not using the highest value. + * Multiple rounds of peers trying to sync with each other and picking the + * median should make them converge in the end, unless an adversarial group + * actively tries to present different values to every honest node. + * + * Another candiate would be to use _mode_. 
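+ * + * As a small worked example of the view number choice, assuming plain + * integer view numbers: + * {{{ + * median(NonEmptySeq.of(4, 7, 9, 10)) == 9 + * }}} + * so with a highest Prepare Q.C. from view 11 the aggregated view number + * would be `max(9, 11) = 11`.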
+ */ + def aggregateStatus[A <: Agreement]( + statuses: NonEmptySeq[Status[A]] + ): Status[A] = { + val prepareQC = statuses.map(_.prepareQC).maximumBy(_.viewNumber) + val commitQC = statuses.map(_.commitQC).maximumBy(_.viewNumber) + val viewNumber = + math.max(median(statuses.map(_.viewNumber)), prepareQC.viewNumber) + Status( + viewNumber = ViewNumber(viewNumber), + prepareQC = prepareQC, + commitQC = commitQC + ) + } + + /** Pick the middle from an ordered sequence of values. + * + * In case of an even number of values, it returns the right + * one from the two values in the middle, it doesn't take the average. + * + * The idea is that we want a value that exists, not something made up, + * and we prefer the higher value, in case this is a progression where + * picking the lower one would mean we'd be left behind. + */ + def median[T: Order](xs: NonEmptySeq[T]): T = + xs.sorted.getUnsafe(xs.size.toInt / 2) + + /** The final status coupled with the federation members that can serve the data. */ + case class FederationStatus[A <: Agreement]( + status: Status[A], + sources: NonEmptyVector[A#PKey] + ) +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/ConsensusEvent.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/ConsensusEvent.scala new file mode 100644 index 00000000..0d97af44 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/ConsensusEvent.scala @@ -0,0 +1,73 @@ +package io.iohk.metronome.hotstuff.service.tracing + +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.hotstuff.consensus.basic.{ + Agreement, + Event, + ProtocolError +} +import io.iohk.metronome.hotstuff.consensus.basic.QuorumCertificate +import io.iohk.metronome.hotstuff.service.ConsensusService.MessageCounter +import io.iohk.metronome.hotstuff.service.Status + +sealed trait ConsensusEvent[+A <: Agreement] + +object ConsensusEvent { + + /** The round ended without having reached decision. */ + case class Timeout( + viewNumber: ViewNumber, + messageCounter: MessageCounter + ) extends ConsensusEvent[Nothing] + + /** A full view synchronization was requested after timing out without any in-sync messages. */ + case class ViewSync( + viewNumber: ViewNumber + ) extends ConsensusEvent[Nothing] + + /** Adopting the view of the federation after a sync. */ + case class AdoptView[A <: Agreement]( + status: Status[A] + ) extends ConsensusEvent[A] + + /** The state advanced to a new view. */ + case class NewView(viewNumber: ViewNumber) extends ConsensusEvent[Nothing] + + /** Quorum over some block. */ + case class Quorum[A <: Agreement](quorumCertificate: QuorumCertificate[A]) + extends ConsensusEvent[A] + + /** A formally valid message was received from an earlier view number. */ + case class FromPast[A <: Agreement](message: Event.MessageReceived[A]) + extends ConsensusEvent[A] + + /** A formally valid message was received from a future view number. */ + case class FromFuture[A <: Agreement](message: Event.MessageReceived[A]) + extends ConsensusEvent[A] + + /** An event that arrived too early but got stashed and will be redelivered. */ + case class Stashed[A <: Agreement]( + error: ProtocolError.TooEarly[A] + ) extends ConsensusEvent[A] + + /** A rejected event. */ + case class Rejected[A <: Agreement]( + error: ProtocolError[A] + ) extends ConsensusEvent[A] + + /** A block has been removed from storage by the time it was to be executed. 
*/ + case class ExecutionSkipped[A <: Agreement]( + blockHash: A#Hash + ) extends ConsensusEvent[A] + + /** A block has been executed. */ + case class BlockExecuted[A <: Agreement]( + blockHash: A#Hash + ) extends ConsensusEvent[A] + + /** An unexpected error in one of the background tasks. */ + case class Error( + message: String, + error: Throwable + ) extends ConsensusEvent[Nothing] +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/ConsensusTracers.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/ConsensusTracers.scala new file mode 100644 index 00000000..2aaf1faa --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/ConsensusTracers.scala @@ -0,0 +1,52 @@ +package io.iohk.metronome.hotstuff.service.tracing + +import cats.implicits._ +import io.iohk.metronome.tracer.Tracer +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.hotstuff.consensus.basic.{ + Agreement, + Event, + ProtocolError, + QuorumCertificate +} +import io.iohk.metronome.hotstuff.service.ConsensusService.MessageCounter +import io.iohk.metronome.hotstuff.service.Status + +case class ConsensusTracers[F[_], A <: Agreement]( + timeout: Tracer[F, (ViewNumber, MessageCounter)], + viewSync: Tracer[F, ViewNumber], + adoptView: Tracer[F, Status[A]], + newView: Tracer[F, ViewNumber], + quorum: Tracer[F, QuorumCertificate[A]], + fromPast: Tracer[F, Event.MessageReceived[A]], + fromFuture: Tracer[F, Event.MessageReceived[A]], + stashed: Tracer[F, ProtocolError.TooEarly[A]], + rejected: Tracer[F, ProtocolError[A]], + executionSkipped: Tracer[F, A#Hash], + blockExecuted: Tracer[F, A#Hash], + error: Tracer[F, (String, Throwable)] +) + +object ConsensusTracers { + import ConsensusEvent._ + + def apply[F[_], A <: Agreement]( + tracer: Tracer[F, ConsensusEvent[A]] + ): ConsensusTracers[F, A] = + ConsensusTracers[F, A]( + timeout = tracer.contramap[(ViewNumber, MessageCounter)]( + (Timeout.apply _).tupled + ), + viewSync = tracer.contramap[ViewNumber](ViewSync(_)), + adoptView = tracer.contramap[Status[A]](AdoptView(_)), + newView = tracer.contramap[ViewNumber](NewView(_)), + quorum = tracer.contramap[QuorumCertificate[A]](Quorum(_)), + fromPast = tracer.contramap[Event.MessageReceived[A]](FromPast(_)), + fromFuture = tracer.contramap[Event.MessageReceived[A]](FromFuture(_)), + stashed = tracer.contramap[ProtocolError.TooEarly[A]](Stashed(_)), + rejected = tracer.contramap[ProtocolError[A]](Rejected(_)), + executionSkipped = tracer.contramap[A#Hash](ExecutionSkipped(_)), + blockExecuted = tracer.contramap[A#Hash](BlockExecuted(_)), + error = tracer.contramap[(String, Throwable)]((Error.apply _).tupled) + ) +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/SyncEvent.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/SyncEvent.scala new file mode 100644 index 00000000..fc822a30 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/SyncEvent.scala @@ -0,0 +1,49 @@ +package io.iohk.metronome.hotstuff.service.tracing + +import io.iohk.metronome.core.Validated +import io.iohk.metronome.hotstuff.consensus.basic.{Agreement, ProtocolError} +import io.iohk.metronome.hotstuff.service.messages.SyncMessage +import io.iohk.metronome.hotstuff.service.Status +import io.iohk.metronome.hotstuff.consensus.basic.ProtocolError + +sealed trait SyncEvent[+A <: Agreement] + +object SyncEvent { + + /** A federation 
member is sending us so many requests that its work queue is full. */ + case class QueueFull[A <: Agreement]( + sender: A#PKey + ) extends SyncEvent[A] + + /** A request we sent couldn't be matched with a response in time. */ + case class RequestTimeout[A <: Agreement]( + recipient: A#PKey, + request: SyncMessage[A] with SyncMessage.Request + ) extends SyncEvent[A] + + /** A response was ignored either because the request ID didn't match, or it already timed out, + * or the response type didn't match the expected one based on the request. + */ + case class ResponseIgnored[A <: Agreement]( + sender: A#PKey, + response: SyncMessage[A] with SyncMessage.Response, + maybeError: Option[Throwable] + ) extends SyncEvent[A] + + /** Performed a poll for `Status` across the federation. + * Only contains results for federation members that responded within the timeout. + */ + case class StatusPoll[A <: Agreement]( + statuses: Map[A#PKey, Validated[Status[A]]] + ) extends SyncEvent[A] + + /** A federation members sent a `Status` with invalid content. */ + case class InvalidStatus[A <: Agreement]( + status: Status[A], + error: ProtocolError.InvalidQuorumCertificate[A], + hint: String + ) extends SyncEvent[A] + + /** An unexpected error in one of the background tasks. */ + case class Error(error: Throwable) extends SyncEvent[Nothing] +} diff --git a/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/SyncTracers.scala b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/SyncTracers.scala new file mode 100644 index 00000000..f480a448 --- /dev/null +++ b/metronome/hotstuff/service/src/io/iohk/metronome/hotstuff/service/tracing/SyncTracers.scala @@ -0,0 +1,49 @@ +package io.iohk.metronome.hotstuff.service.tracing + +import cats.implicits._ +import io.iohk.metronome.core.Validated +import io.iohk.metronome.tracer.Tracer +import io.iohk.metronome.hotstuff.consensus.basic.{Agreement, ProtocolError} +import io.iohk.metronome.hotstuff.service.messages.SyncMessage +import io.iohk.metronome.hotstuff.service.Status + +case class SyncTracers[F[_], A <: Agreement]( + queueFull: Tracer[F, A#PKey], + requestTimeout: Tracer[F, SyncTracers.Request[A]], + responseIgnored: Tracer[F, SyncTracers.Response[A]], + statusPoll: Tracer[F, SyncTracers.Statuses[A]], + invalidStatus: Tracer[F, SyncTracers.StatusError[A]], + error: Tracer[F, Throwable] +) + +object SyncTracers { + import SyncEvent._ + + type Request[A <: Agreement] = + (A#PKey, SyncMessage[A] with SyncMessage.Request) + + type Response[A <: Agreement] = + (A#PKey, SyncMessage[A] with SyncMessage.Response, Option[Throwable]) + + type Statuses[A <: Agreement] = + Map[A#PKey, Validated[Status[A]]] + + type StatusError[A <: Agreement] = + (Status[A], ProtocolError.InvalidQuorumCertificate[A], String) + + def apply[F[_], A <: Agreement]( + tracer: Tracer[F, SyncEvent[A]] + ): SyncTracers[F, A] = + SyncTracers[F, A]( + queueFull = tracer.contramap[A#PKey](QueueFull(_)), + requestTimeout = tracer + .contramap[Request[A]]((RequestTimeout.apply[A] _).tupled), + responseIgnored = tracer + .contramap[Response[A]]((ResponseIgnored.apply[A] _).tupled), + statusPoll = tracer + .contramap[Statuses[A]](StatusPoll(_)), + invalidStatus = + tracer.contramap[StatusError[A]]((InvalidStatus.apply[A] _).tupled), + error = tracer.contramap[Throwable](Error(_)) + ) +} diff --git a/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/MessageStashSpec.scala 
b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/MessageStashSpec.scala new file mode 100644 index 00000000..1ef30214 --- /dev/null +++ b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/MessageStashSpec.scala @@ -0,0 +1,111 @@ +package io.iohk.metronome.hotstuff.service + +import io.iohk.metronome.hotstuff.consensus.basic.Agreement +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import io.iohk.metronome.hotstuff.consensus.basic.{ + ProtocolError, + Event, + Message, + Phase, + QuorumCertificate +} +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.crypto.GroupSignature + +class MessageStashSpec extends AnyFlatSpec with Matchers { + import ConsensusService.MessageStash + + object TestAgreement extends Agreement { + override type Block = Nothing + override type Hash = Int + override type PSig = Nothing + override type GSig = Int + override type PKey = String + override type SKey = Nothing + } + type TestAgreement = TestAgreement.type + + "MessageStash" should behave like { + + val emptyStash = MessageStash.empty[TestAgreement] + + val error = ProtocolError.TooEarly[TestAgreement]( + Event.MessageReceived[TestAgreement]( + "Alice", + Message.NewView( + ViewNumber(10), + QuorumCertificate[TestAgreement]( + Phase.Prepare, + ViewNumber(9), + 123, + GroupSignature(456) + ) + ) + ), + expectedInViewNumber = ViewNumber(11), + expectedInPhase = Phase.Prepare + ) + val errorSlotKey = (error.expectedInViewNumber, error.expectedInPhase) + + it should "stash errors" in { + emptyStash.slots shouldBe empty + + val stash = emptyStash.stash(error) + + stash.slots should contain key errorSlotKey + stash.slots(errorSlotKey) should contain key error.event.sender + stash.slots(errorSlotKey)(error.event.sender) shouldBe error.event.message + } + + it should "stash only the last message from a sender" in { + val error2 = error.copy(event = + error.event.copy(message = + Message.NewView( + ViewNumber(10), + QuorumCertificate[TestAgreement]( + Phase.Prepare, + ViewNumber(8), + 122, + GroupSignature(455) + ) + ) + ) + ) + val stash = emptyStash.stash(error).stash(error2) + + stash.slots(errorSlotKey)( + error.event.sender + ) shouldBe error2.event.message + } + + it should "unstash due errors" in { + val errors = List( + error, + error.copy( + expectedInPhase = Phase.PreCommit + ), + error.copy( + expectedInViewNumber = error.expectedInViewNumber.next + ), + error.copy( + expectedInViewNumber = error.expectedInViewNumber.next, + expectedInPhase = Phase.Commit + ), + error.copy( + expectedInViewNumber = error.expectedInViewNumber.next.next + ) + ) + + val stash0 = errors.foldLeft(emptyStash)(_ stash _) + + val (stash1, unstashed) = stash0.unstash( + errors(2).expectedInViewNumber, + errors(2).expectedInPhase + ) + + stash1.slots.keySet should have size 2 + unstashed should have size 3 + } + } +} diff --git a/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/execution/BlockExecutorProps.scala b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/execution/BlockExecutorProps.scala new file mode 100644 index 00000000..18d1af0e --- /dev/null +++ b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/execution/BlockExecutorProps.scala @@ -0,0 +1,393 @@ +package io.iohk.metronome.hotstuff.service.execution + +import cats.implicits._ +import cats.effect.Resource +import cats.effect.concurrent.{Ref, Semaphore} +import cats.data.{NonEmptyVector, 
NonEmptyList} +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.hotstuff.consensus.basic.{ + Effect, + QuorumCertificate, + Phase +} +import io.iohk.metronome.crypto.GroupSignature +import io.iohk.metronome.hotstuff.service.ApplicationService +import io.iohk.metronome.hotstuff.service.storage.ViewStateStorage +import io.iohk.metronome.hotstuff.service.storage.{ + BlockStorageProps, + ViewStateStorageCommands +} +import io.iohk.metronome.hotstuff.service.tracing.{ + ConsensusEvent, + ConsensusTracers +} +import io.iohk.metronome.storage.InMemoryKVStore +import io.iohk.metronome.tracer.Tracer +import monix.eval.Task +import monix.execution.Scheduler +import org.scalacheck.{Properties, Arbitrary, Gen} +import org.scalacheck.Prop, Prop.{forAll, propBoolean, all} +import scala.concurrent.duration._ + +object BlockExecutorProps extends Properties("BlockExecutor") { + import BlockStorageProps.{ + TestAgreement, + TestBlock, + TestBlockStorage, + TestKVStore, + Namespace, + genNonEmptyBlockTree + } + import ViewStateStorageCommands.neverUsedCodec + + case class TestResources( + blockExecutor: BlockExecutor[Task, Namespace, TestAgreement], + viewStateStorage: ViewStateStorage[Namespace, TestAgreement], + executionSemaphore: Semaphore[Task] + ) + + case class TestFixture( + blocks: List[TestBlock], + batches: Vector[Effect.ExecuteBlocks[TestAgreement]] + ) { + val storeRef = Ref.unsafe[Task, TestKVStore.Store] { + TestKVStore.build(blocks) + } + val eventsRef = + Ref.unsafe[Task, Vector[ConsensusEvent[TestAgreement]]](Vector.empty) + + val store = InMemoryKVStore[Task, Namespace](storeRef) + + implicit val storeRunner = store + + val eventTracer = + Tracer.instance[Task, ConsensusEvent[TestAgreement]] { event => + eventsRef.update(_ :+ event) + } + + implicit val consensusTracers = ConsensusTracers(eventTracer) + + val failNextRef = Ref.unsafe[Task, Boolean](false) + val isExecutingRef = Ref.unsafe[Task, Boolean](false) + + private def appService(semaphore: Semaphore[Task]) = + new ApplicationService[Task, TestAgreement] { + def createBlock( + highQC: QuorumCertificate[TestAgreement] + ): Task[Option[TestBlock]] = ??? + + def validateBlock(block: TestBlock): Task[Option[Boolean]] = ??? 
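+ + // Block proposal and validation are not exercised by these properties, + // so the fixture leaves them unimplemented.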
+ + def syncState( + sources: NonEmptyVector[Int], + block: TestBlock + ): Task[Boolean] = + Task.pure(true) + + def executeBlock( + block: TestBlock, + commitQC: QuorumCertificate[TestAgreement], + commitPath: NonEmptyList[TestAgreement.Hash] + ): Task[Boolean] = + isExecutingRef + .set(true) + .bracket(_ => + semaphore.withPermit { + for { + fail <- failNextRef.modify(failNext => (false, failNext)) + _ <- Task + .raiseError(new RuntimeException("The application failed!")) + .whenA(fail) + } yield true + } + )(_ => isExecutingRef.set(false)) + + } + + val resources: Resource[Task, TestResources] = + for { + viewStateStorage <- Resource.liftF { + storeRunner.runReadWrite { + val genesisQC = QuorumCertificate[TestAgreement]( + phase = Phase.Commit, + viewNumber = ViewNumber(0), + blockHash = blocks.head.id, + signature = GroupSignature(()) + ) + val genesisBundle = ViewStateStorage.Bundle.fromGenesisQC(genesisQC) + + ViewStateStorage[Namespace, TestAgreement]( + "view-state", + genesisBundle + ) + } + } + semaphore <- Resource.liftF(Semaphore[Task](1)) + blockExecutor <- BlockExecutor[Task, Namespace, TestAgreement]( + appService(semaphore), + TestBlockStorage, + viewStateStorage + ) + } yield TestResources(blockExecutor, viewStateStorage, semaphore) + + val executedBlockHashes = + eventsRef.get + .map { events => + events.collect { case ConsensusEvent.BlockExecuted(blockHash) => + blockHash + } + } + + val lastBatchCommitedBlockHash = + batches.last.quorumCertificate.blockHash + + def awaitBlockExecution( + blockHash: TestAgreement.Hash + ): Task[Vector[TestAgreement.Hash]] = { + executedBlockHashes + .restartUntil { blockHashes => + blockHashes.contains(blockHash) + } + } + } + + object TestFixture { + implicit val arb: Arbitrary[TestFixture] = Arbitrary(gen()) + + /** Create a random number of tree extensions, with each extension + * covered by a batch that goes from its root to one of its leaves. + */ + def gen(minBatches: Int = 1, maxBatches: Int = 5): Gen[TestFixture] = { + def loop( + i: Int, + tree: List[TestBlock], + effects: Vector[Effect.ExecuteBlocks[TestAgreement]] + ): Gen[TestFixture] = { + if (i == 0) { + Gen.const(TestFixture(tree, effects)) + } else { + val extension = for { + viewNumber <- Gen.posNum[Int].map(ViewNumber(_)) + ancestor = tree.last + descendantTree <- genNonEmptyBlockTree(parent = ancestor) + descendant = descendantTree.last + commitQC = QuorumCertificate[TestAgreement]( + phase = Phase.Commit, + viewNumber = viewNumber, + blockHash = descendant.id, + signature = GroupSignature(()) + ) + effect = Effect.ExecuteBlocks[TestAgreement]( + lastExecutedBlockHash = ancestor.id, + quorumCertificate = commitQC + ) + } yield (tree ++ descendantTree, effects :+ effect) + + extension.flatMap { case (tree, effects) => + loop(i - 1, tree, effects) + } + } + } + + for { + prefixTree <- genNonEmptyBlockTree + i <- Gen.choose(minBatches, maxBatches) + fixture <- loop(i, prefixTree, Vector.empty) + } yield fixture + } + } + + def run(test: Task[Prop]): Prop = { + import Scheduler.Implicits.global + test.runSyncUnsafe(timeout = 5.seconds) + } + + property("executeBlocks - from root") = forAll { (fixture: TestFixture) => + run { + fixture.resources.use { res => + for { + _ <- fixture.batches.traverse(res.blockExecutor.enqueue) + + executedBlockHashes <- fixture.awaitBlockExecution( + fixture.lastBatchCommitedBlockHash + ) + + // The genesis was the only block we marked as executed. 
+ pathFromRoot <- fixture.storeRunner.runReadOnly { + TestBlockStorage.getPathFromRoot(fixture.lastBatchCommitedBlockHash) + } + + } yield { + "executes from the root" |: executedBlockHashes == pathFromRoot.tail + } + } + } + } + + property("executeBlocks - from last") = forAll { (fixture: TestFixture) => + run { + fixture.resources.use { res => + val lastBatch = fixture.batches.last + val lastExecutedBlockHash = lastBatch.lastExecutedBlockHash + for { + _ <- fixture.storeRunner.runReadWrite { + res.viewStateStorage.setLastExecutedBlockHash(lastExecutedBlockHash) + } + _ <- res.blockExecutor.enqueue(lastBatch) + + executedBlockHashes <- fixture.awaitBlockExecution( + fixture.lastBatchCommitedBlockHash + ) + + pathFromLast <- fixture.storeRunner.runReadOnly { + TestBlockStorage.getPathFromAncestor( + lastExecutedBlockHash, + fixture.lastBatchCommitedBlockHash + ) + } + + } yield { + "executes from the last" |: executedBlockHashes == pathFromLast.tail + } + } + } + } + + property("executeBlocks - from pruned") = forAll { (fixture: TestFixture) => + run { + fixture.resources.use { res => + val lastBatch = fixture.batches.last + val lastExecutedBlockHash = lastBatch.lastExecutedBlockHash + for { + _ <- fixture.storeRunner.runReadWrite { + TestBlockStorage.pruneNonDescendants(lastExecutedBlockHash) + } + _ <- res.blockExecutor.enqueue(lastBatch) + + executedBlockHashes <- fixture.awaitBlockExecution( + fixture.lastBatchCommitedBlockHash + ) + + // The last executed block should be the new root. + pathFromRoot <- fixture.storeRunner.runReadOnly { + TestBlockStorage.getPathFromRoot(fixture.lastBatchCommitedBlockHash) + } + } yield { + all( + "new root" |: pathFromRoot.head == lastExecutedBlockHash, + "executes from the last" |: executedBlockHashes == pathFromRoot.tail + ) + } + } + } + } + + property("executeBlocks - from failed") = + // Only the next commit batch triggers re-execution, so we need at least 2. + forAll(TestFixture.gen(minBatches = 2)) { (fixture: TestFixture) => + run { + fixture.resources.use { res => + for { + _ <- fixture.failNextRef.set(true) + _ <- fixture.batches.traverse(res.blockExecutor.enqueue) + _ <- fixture.awaitBlockExecution(fixture.lastBatchCommitedBlockHash) + events <- fixture.eventsRef.get + } yield { + 1 === events.count { + case _: ConsensusEvent.Error => true + case _ => false + } + } + } + } + } + + property("executeBlocks - skipped") = + // Using 4 batches so the 2nd batch definitely doesn't start with the last executed block, + // which will be the root initially, and it's distinct from the last batch as well. + forAll(TestFixture.gen(minBatches = 4)) { (fixture: TestFixture) => + run { + fixture.resources.use { res => + val execBatch = fixture.batches.tail.head + val lastBatch = fixture.batches.last + for { + // Make the execution wait until we update the view state. + _ <- res.executionSemaphore.acquire + _ <- res.blockExecutor.enqueue(execBatch) + + // Wait until the execution has started before updating the view state + // so that all the blocks are definitely enqueued already. + _ <- fixture.isExecutingRef.get.restartUntil(identity) + + // Now skip ahead, like if we did a fast-forward sync. + _ <- fixture.storeRunner.runReadWrite { + res.viewStateStorage.setLastExecutedBlockHash( + lastBatch.lastExecutedBlockHash + ) + } + _ <- res.executionSemaphore.release + + // Easiest indicator of everything being finished is to execute the last batch. 
+ _ <- res.blockExecutor.enqueue(lastBatch) + _ <- fixture.awaitBlockExecution( + lastBatch.quorumCertificate.blockHash + ) + + events <- fixture.eventsRef.get + executedBlockHashes = events.collect { + case ConsensusEvent.BlockExecuted(blockHash) => blockHash + } + skippedBlockHashes = events.collect { + case ConsensusEvent.ExecutionSkipped(blockHash) => blockHash + } + + path <- fixture.storeRunner.runReadOnly { + TestBlockStorage.getPathFromRoot( + execBatch.quorumCertificate.blockHash + ) + } + } yield { + all( + // The first block after the root will be executed, only then do we skip the rest. + "executes the first block" |: executedBlockHashes.head == path.tail.head, + "skips rest of the blocks" |: skippedBlockHashes == path.drop(2) + ) + } + } + } + } + + property("syncState") = forAll { (fixture: TestFixture) => + run { + fixture.resources.use { res => + val lastBatch = fixture.batches.last + for { + block <- fixture.storeRunner.runReadOnly { + TestBlockStorage.get(lastBatch.lastExecutedBlockHash).map(_.get) + } + _ <- res.blockExecutor.syncState( + sources = NonEmptyVector.one(0), + block = block + ) + _ <- fixture.batches.traverse(res.blockExecutor.enqueue) + + executedBlockHashes <- fixture.awaitBlockExecution( + fixture.lastBatchCommitedBlockHash + ) + + // The last executed block should be the new root after pruning away old blocks. + pathFromRoot <- fixture.storeRunner.runReadOnly { + TestBlockStorage.getPathFromRoot( + fixture.lastBatchCommitedBlockHash + ) + } + } yield { + all( + "prunes to the fast forwared block" |: pathFromRoot.head == lastBatch.lastExecutedBlockHash, + "executes from the fast forwarded block" |: executedBlockHashes == pathFromRoot.tail + ) + } + } + } + } +} diff --git a/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/storage/BlockStorageProps.scala b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/storage/BlockStorageProps.scala new file mode 100644 index 00000000..993755d3 --- /dev/null +++ b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/storage/BlockStorageProps.scala @@ -0,0 +1,408 @@ +package io.iohk.metronome.hotstuff.service.storage + +import cats.implicits._ +import io.iohk.metronome.storage.{KVCollection, KVStoreState, KVTree} +import io.iohk.metronome.hotstuff.consensus.basic.{Agreement, Block => BlockOps} +import org.scalacheck._ +import org.scalacheck.Arbitrary.arbitrary +import org.scalacheck.Prop.{all, forAll, propBoolean} +import scodec.codecs.implicits._ +import scodec.Codec +import scala.util.Random + +object BlockStorageProps extends Properties("BlockStorage") { + + case class TestBlock(id: String, parentId: String, height: Long) { + def isGenesis = parentId.isEmpty + } + + object TestAgreement extends Agreement { + type Block = TestBlock + type Hash = String + type PSig = Nothing + type GSig = Unit + type PKey = Int + type SKey = Nothing + + implicit val block = new BlockOps[TestAgreement] { + override def blockHash(b: TestBlock) = b.id + override def parentBlockHash(b: TestBlock) = b.parentId + override def height(b: Block): Long = b.height + override def isValid(b: Block) = true + } + } + type TestAgreement = TestAgreement.type + type Hash = TestAgreement.Hash + + implicit def `Codec[Set[T]]`[T: Codec] = + implicitly[Codec[List[T]]].xmap[Set[T]](_.toSet, _.toList) + + type Namespace = String + object Namespace { + val Blocks = "blocks" + val BlockMetas = "block-metas" + val BlockToChildren = "block-to-children" + } + + object TestBlockStorage + extends 
BlockStorage[Namespace, TestAgreement]( + new KVCollection[Namespace, Hash, TestBlock](Namespace.Blocks), + new KVCollection[Namespace, Hash, KVTree.NodeMeta[Hash]]( + Namespace.BlockMetas + ), + new KVCollection[Namespace, Hash, Set[Hash]](Namespace.BlockToChildren) + ) + + object TestKVStore extends KVStoreState[Namespace] { + def build(tree: List[TestBlock]): Store = { + val insert = tree.map(TestBlockStorage.put).sequence + compile(insert).runS(Map.empty).value + } + } + + implicit class TestStoreOps(store: TestKVStore.Store) { + def putBlock(block: TestBlock) = + TestKVStore.compile(TestBlockStorage.put(block)).runS(store).value + + def containsBlock(blockHash: Hash) = + TestKVStore + .compile(TestBlockStorage.contains(blockHash)) + .run(store) + + def getBlock(blockHash: Hash) = + TestKVStore + .compile(TestBlockStorage.get(blockHash)) + .run(store) + + def deleteBlock(blockHash: Hash) = + TestKVStore + .compile(TestBlockStorage.delete(blockHash)) + .run(store) + .value + + def getPathFromRoot(blockHash: Hash) = + TestKVStore + .compile(TestBlockStorage.getPathFromRoot(blockHash)) + .run(store) + + def getPathFromAncestor( + ancestorBlockHash: Hash, + descendantBlockHash: Hash + ) = + TestKVStore + .compile( + TestBlockStorage + .getPathFromAncestor( + ancestorBlockHash, + descendantBlockHash + ) + ) + .run(store) + + def getDescendants(blockHash: Hash) = + TestKVStore + .compile(TestBlockStorage.getDescendants(blockHash)) + .run(store) + + def pruneNonDescendants(blockHash: Hash) = + TestKVStore + .compile(TestBlockStorage.pruneNonDescendants(blockHash)) + .run(store) + .value + + def purgeTree(blockHash: Hash, keep: Option[Hash]) = + TestKVStore + .compile(TestBlockStorage.purgeTree(blockHash, keep)) + .run(store) + .value + } + + def genBlockId: Gen[Hash] = + Gen.uuid.map(_.toString) + + /** Generate a block with a given parent, using the next available ID. */ + def genBlock(parent: TestBlock): Gen[TestBlock] = + genBlockId.map { uuid => + TestBlock(uuid, parentId = parent.id, height = parent.height + 1) + } + + def genBlock: Gen[TestBlock] = + for { + id <- genBlockId + parentId <- genBlockId + height <- Gen.posNum[Long] + } yield TestBlock(id, parentId, height) + + // A block we can pass as parent to tree generators so the first block is a + // genesis block with height = 0 and parentId = "". + val preGenesisParent = TestBlock(id = "", parentId = "", height = -1) + + /** Generate a (possibly empty) block tree. 
*/ + def genBlockTree(parent: TestBlock): Gen[List[TestBlock]] = + for { + childCount <- Gen.frequency( + 3 -> 0, + 5 -> 1, + 2 -> 2 + ) + children <- Gen.listOfN( + childCount, { + for { + block <- genBlock(parent) + tree <- genBlockTree(block) + } yield block +: tree + } + ) + } yield children.flatten + + def genBlockTree: Gen[List[TestBlock]] = + genBlockTree(preGenesisParent) + + def genNonEmptyBlockTree(parent: TestBlock): Gen[List[TestBlock]] = for { + child <- genBlock(parent) + tree <- genBlockTree(child) + } yield child +: tree + + def genNonEmptyBlockTree: Gen[List[TestBlock]] = + genNonEmptyBlockTree(preGenesisParent) + + case class TestData( + tree: List[TestBlock], + store: TestKVStore.Store + ) + object TestData { + def apply(tree: List[TestBlock]): TestData = { + val store = TestKVStore.build(tree) + TestData(tree, store) + } + } + + def genExisting = for { + tree <- genNonEmptyBlockTree + existing <- Gen.oneOf(tree) + data = TestData(tree) + } yield (data, existing) + + def genNonExisting = for { + tree <- genBlockTree + nonExisting <- genBlock + data = TestData(tree) + } yield (data, nonExisting) + + def genSubTree = for { + tree <- genNonEmptyBlockTree + leaf = tree.last + subTree <- genBlockTree(parent = leaf) + data = TestData(tree ++ subTree) + } yield (data, leaf, subTree) + + property("put") = forAll(genNonExisting) { case (data, block) => + val s = data.store.putBlock(block) + s(Namespace.Blocks)(block.id) == block + s(Namespace.BlockMetas)(block.id) + .asInstanceOf[KVTree.NodeMeta[Hash]] + .parentKey == block.parentId + } + + property("put unordered") = forAll { + for { + ordered <- genNonEmptyBlockTree + seed <- arbitrary[Int] + unordered = new Random(seed).shuffle(ordered) + } yield (ordered, unordered) + } { case (ordered, unordered) => + val orderedStore = TestKVStore.build(ordered) + val unorderedStore = TestKVStore.build(unordered) + orderedStore == unorderedStore + } + + property("contains existing") = forAll(genExisting) { case (data, existing) => + data.store.containsBlock(existing.id) + } + + property("contains non-existing") = forAll(genNonExisting) { + case (data, nonExisting) => + !data.store.containsBlock(nonExisting.id) + } + + property("get existing") = forAll(genExisting) { case (data, existing) => + data.store.getBlock(existing.id).contains(existing) + } + + property("get non-existing") = forAll(genNonExisting) { + case (data, nonExisting) => + data.store.getBlock(nonExisting.id).isEmpty + } + + property("delete existing") = forAll(genExisting) { case (data, existing) => + val childCount = data.tree.count(_.parentId == existing.id) + val noParent = !data.tree.exists(_.id == existing.parentId) + val (s, ok) = data.store.deleteBlock(existing.id) + all( + "deleted" |: s.containsBlock(existing.id) == !ok, + "ok" |: ok && (childCount == 0 || childCount == 1 && noParent) || !ok + ) + } + + property("delete non-existing") = forAll(genNonExisting) { + case (data, nonExisting) => + data.store.deleteBlock(nonExisting.id)._2 == true + } + + property("reinsert one") = forAll(genExisting) { case (data, existing) => + val (deleted, _) = data.store.deleteBlock(existing.id) + val inserted = deleted.putBlock(existing) + // The existing child relationships should not be lost. 
+ inserted == data.store + } + + property("getPathFromRoot existing") = forAll(genExisting) { + case (data, existing) => + val path = data.store.getPathFromRoot(existing.id) + all( + "nonEmpty" |: path.nonEmpty, + "head" |: path.headOption.contains(data.tree.head.id), + "last" |: path.lastOption.contains(existing.id) + ) + } + + property("getPathFromRoot non-existing") = forAll(genNonExisting) { + case (data, nonExisting) => + data.store.getPathFromRoot(nonExisting.id).isEmpty + } + + property("getPathFromAncestor") = forAll( + for { + prefix <- genNonEmptyBlockTree + ancestor = prefix.last + postfix <- genNonEmptyBlockTree(ancestor) + descendant <- Gen.oneOf(postfix) + data = TestData(prefix ++ postfix) + nonExisting <- genBlock + } yield (data, ancestor, descendant, nonExisting) + ) { case (data, ancestor, descendant, nonExisting) => + def getPath(a: TestBlock, d: TestBlock) = + data.store.getPathFromAncestor(a.id, d.id) + + def pathExists(a: TestBlock, d: TestBlock) = { + val path = getPath(a, d) + path.nonEmpty && + path.distinct.size == path.size && + path.head == a.id && + path.last == d.id && + (path.init zip path.tail).forall { case (parentHash, childHash) => + data.store.getBlock(childHash).get.parentId == parentHash + } + } + + def pathNotExists(a: TestBlock, d: TestBlock) = + getPath(a, d).isEmpty + + all( + "fromAtoD" |: pathExists(ancestor, descendant), + "fromDtoA" |: pathNotExists(descendant, ancestor), + "fromAtoA" |: pathExists(ancestor, ancestor), + "fromDtoD" |: pathExists(descendant, descendant), + "fromAtoN" |: pathNotExists(ancestor, nonExisting), + "fromNtoD" |: pathNotExists(nonExisting, descendant) + ) + } + + property("getDescendants existing") = forAll(genSubTree) { + case (data, block, subTree) => + val ds = data.store.getDescendants(block.id) + val dss = ds.toSet + all( + "nonEmpty" |: ds.nonEmpty, + "last" |: ds.lastOption.contains(block.id), + "size" |: ds.size == subTree.size + 1, + "subtree" |: subTree.forall(block => dss.contains(block.id)) + ) + } + + property("getDescendants non-existing") = forAll(genNonExisting) { + case (data, nonExisting) => + data.store.getDescendants(nonExisting.id).isEmpty + } + + property("getDescendants delete") = forAll(genSubTree) { + case (data, block, _) => + val ds = data.store.getDescendants(block.id) + + val (deleted, ok) = ds.foldLeft((data.store, true)) { + case ((store, oks), blockHash) => + val (deleted, ok) = store.deleteBlock(blockHash) + (deleted, oks && ok) + } + + val prefixTree = data.tree.takeWhile(_ != block) + val prefixStore = TestKVStore.build(prefixTree) + + all( + "ok" |: ok, + "not contains deleted" |: + ds.forall(!deleted.containsBlock(_)), + "contains non deleted" |: + prefixTree.map(_.id).forall(deleted.containsBlock(_)), + "same as a rebuild" |: + prefixStore == deleted + ) + } + + property("pruneNonDescendants existing") = forAll(genSubTree) { + case (data, block, subTree) => + val (s, ps) = data.store.pruneNonDescendants(block.id) + val pss = ps.toSet + val descendants = subTree.map(_.id).toSet + val nonDescendants = + data.tree.map(_.id).filterNot(descendants).filterNot(_ == block.id) + all( + "size" |: ps.size == nonDescendants.size, + "pruned" |: nonDescendants.forall(pss), + "deleted" |: nonDescendants.forall(!s.containsBlock(_)), + "kept-block" |: s.containsBlock(block.id), + "kept-descendants" |: descendants.forall(s.containsBlock(_)) + ) + } + + property("pruneNonDescendants non-existing") = forAll(genNonExisting) { + case (data, nonExisting) => + 
data.store.pruneNonDescendants(nonExisting.id)._2.isEmpty + } + + property("purgeTree keep block") = forAll( + for { + (data, keepBlock, subTree) <- genSubTree + refBlock <- Gen.oneOf(data.tree) + } yield (data, refBlock, keepBlock, subTree) + ) { case (data, refBlock, keepBlock, subTree) => + val (s, ps) = data.store.purgeTree( + blockHash = refBlock.id, + keep = Some(keepBlock.id) + ) + val pss = ps.toSet + val descendants = subTree.map(_.id).toSet + val nonDescendants = + data.tree.map(_.id).filterNot(descendants).filterNot(_ == keepBlock.id) + all( + "size" |: ps.size == nonDescendants.size, + "pruned" |: nonDescendants.forall(pss), + "deleted" |: nonDescendants.forall(!s.containsBlock(_)), + "kept-block" |: s.containsBlock(keepBlock.id), + "kept-descendants" |: descendants.forall(s.containsBlock(_)) + ) + } + + property("purgeTree keep nothing") = forAll(genSubTree) { + case (data, block, _) => + val (s, ps) = data.store.purgeTree( + blockHash = block.id, + keep = None + ) + val pss = ps.toSet + all( + "pruned all" |: pss.size == data.tree.size, + "kept nothing" |: s.isEmpty + ) + } +} diff --git a/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/storage/ViewStateStorageProps.scala b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/storage/ViewStateStorageProps.scala new file mode 100644 index 00000000..c95dc5c9 --- /dev/null +++ b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/storage/ViewStateStorageProps.scala @@ -0,0 +1,218 @@ +package io.iohk.metronome.hotstuff.service.storage + +import io.iohk.metronome.hotstuff.consensus.ViewNumber +import io.iohk.metronome.hotstuff.consensus.basic.{ + QuorumCertificate, + Phase, + Agreement +} +import scala.annotation.nowarn +import io.iohk.metronome.crypto.GroupSignature +import io.iohk.metronome.storage.{KVStore, KVStoreRead, KVStoreState} +import org.scalacheck.{Gen, Prop, Properties} +import org.scalacheck.commands.Commands +import org.scalacheck.Arbitrary.arbitrary +import scala.util.Try +import scodec.bits.BitVector +import scodec.Codec +import scala.util.Success + +object ViewStateStorageProps extends Properties("ViewStateStorage") { + property("commands") = ViewStateStorageCommands.property() +} + +object ViewStateStorageCommands extends Commands { + object TestAgreement extends Agreement { + type Block = Nothing + type Hash = String + type PSig = Unit + type GSig = List[String] + type PKey = Nothing + type SKey = Nothing + } + type TestAgreement = TestAgreement.type + + type Namespace = String + + object TestKVStoreState extends KVStoreState[Namespace] + + type TestViewStateStorage = ViewStateStorage[Namespace, TestAgreement] + + class StorageWrapper( + viewStateStorage: TestViewStateStorage, + private var store: TestKVStoreState.Store + ) { + def getStore = store + + def write( + f: TestViewStateStorage => KVStore[Namespace, Unit] + ): Unit = { + store = TestKVStoreState.compile(f(viewStateStorage)).runS(store).value + } + + def read[A]( + f: TestViewStateStorage => KVStoreRead[Namespace, A] + ): A = { + TestKVStoreState.compile(f(viewStateStorage)).run(store) + } + } + + type State = ViewStateStorage.Bundle[TestAgreement] + type Sut = StorageWrapper + + val genesisState = ViewStateStorage.Bundle + .fromGenesisQC[TestAgreement] { + QuorumCertificate[TestAgreement]( + Phase.Prepare, + ViewNumber(1), + "", + GroupSignature(Nil) + ) + } + + /** The in-memory KVStoreState doesn't invoke the codecs. 
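+    * A codec that throws on both encode and decode is therefore safe to use here;
+    * if it were ever invoked, the test would fail loudly.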
*/ + implicit def neverUsedCodec[T] = + Codec[T]( + (_: T) => sys.error("Didn't expect to encode."), + (_: BitVector) => sys.error("Didn't expect to decode.") + ) + + @nowarn + override def canCreateNewSut( + newState: State, + initSuts: Traversable[State], + runningSuts: Traversable[Sut] + ): Boolean = true + + override def initialPreCondition(state: State): Boolean = + state == genesisState + + override def newSut(state: State): Sut = { + val init = TestKVStoreState.compile( + ViewStateStorage[Namespace, TestAgreement]("test-namespace", state) + ) + val (store, storage) = init.run(Map.empty).value + new StorageWrapper(storage, store) + } + + override def destroySut(sut: Sut): Unit = () + + override def genInitialState: Gen[State] = Gen.const(genesisState) + + override def genCommand(state: State): Gen[Command] = + Gen.oneOf( + genSetViewNumber(state), + genSetQuorumCertificate(state), + genSetLastExecutedBlockHash(state), + genGetBundle + ) + + def genSetViewNumber(state: State) = + for { + d <- Gen.posNum[Long] + vn = ViewNumber(state.viewNumber + d) + } yield SetViewNumberCommand(vn) + + def genSetQuorumCertificate(state: State) = + for { + p <- Gen.oneOf(Phase.Prepare, Phase.PreCommit, Phase.Commit) + h <- arbitrary[TestAgreement.Hash] + s <- arbitrary[TestAgreement.GSig] + qc = QuorumCertificate[TestAgreement]( + p, + state.viewNumber, + h, + GroupSignature(s) + ) + } yield SetQuorumCertificateCommand(qc) + + def genSetLastExecutedBlockHash(state: State) = + for { + h <- Gen.oneOf( + state.prepareQC.blockHash, + state.lockedQC.blockHash, + state.commitQC.blockHash + ) + } yield SetLastExecutedBlockHashCommand(h) + + def genSetRootBlockHash(state: State) = + for { + h <- Gen.oneOf( + state.prepareQC.blockHash, + state.lockedQC.blockHash, + state.commitQC.blockHash + ) + } yield SetRootBlockHashCommand(h) + + val genGetBundle = Gen.const(GetBundleCommand) + + case class SetViewNumberCommand(viewNumber: ViewNumber) extends UnitCommand { + override def run(sut: Sut): Result = + sut.write(_.setViewNumber(viewNumber)) + override def nextState(state: State): State = + state.copy(viewNumber = viewNumber) + override def preCondition(state: State): Boolean = + state.viewNumber < viewNumber + override def postCondition(state: State, success: Boolean): Prop = success + } + + case class SetQuorumCertificateCommand(qc: QuorumCertificate[TestAgreement]) + extends UnitCommand { + override def run(sut: Sut): Result = + sut.write(_.setQuorumCertificate(qc)) + + override def nextState(state: State): State = + qc.phase match { + case Phase.Prepare => state.copy(prepareQC = qc) + case Phase.PreCommit => state.copy(lockedQC = qc) + case Phase.Commit => state.copy(commitQC = qc) + } + + override def preCondition(state: State): Boolean = + state.viewNumber <= qc.viewNumber + + override def postCondition(state: State, success: Boolean): Prop = success + } + + case class SetLastExecutedBlockHashCommand(blockHash: TestAgreement.Hash) + extends UnitCommand { + override def run(sut: Sut): Result = + sut.write(_.setLastExecutedBlockHash(blockHash)) + + override def nextState(state: State): State = + state.copy(lastExecutedBlockHash = blockHash) + + override def preCondition(state: State): Boolean = + Set(state.prepareQC, state.lockedQC, state.commitQC) + .map(_.blockHash) + .contains(blockHash) + + override def postCondition(state: State, success: Boolean): Prop = success + } + + case class SetRootBlockHashCommand(blockHash: TestAgreement.Hash) + extends UnitCommand { + override def run(sut: Sut): Result = + 
sut.write(_.setRootBlockHash(blockHash)) + + override def nextState(state: State): State = + state.copy(rootBlockHash = blockHash) + + override def preCondition(state: State): Boolean = + Set(state.prepareQC, state.lockedQC, state.commitQC) + .map(_.blockHash) + .contains(blockHash) + + override def postCondition(state: State, success: Boolean): Prop = success + } + + case object GetBundleCommand extends Command { + type Result = ViewStateStorage.Bundle[TestAgreement] + + override def run(sut: Sut): Result = sut.read(_.getBundle) + override def nextState(state: State): State = state + override def preCondition(state: State): Boolean = true + override def postCondition(state: State, result: Try[Result]): Prop = + result == Success(state) + } +} diff --git a/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/sync/BlockSynchronizerProps.scala b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/sync/BlockSynchronizerProps.scala new file mode 100644 index 00000000..dd6596c2 --- /dev/null +++ b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/sync/BlockSynchronizerProps.scala @@ -0,0 +1,281 @@ +package io.iohk.metronome.hotstuff.service.sync + +import cats.implicits._ +import cats.data.NonEmptyVector +import cats.effect.concurrent.{Ref, Semaphore} +import io.iohk.metronome.crypto.GroupSignature +import io.iohk.metronome.hotstuff.consensus.{ + ViewNumber, + Federation, + LeaderSelection +} +import io.iohk.metronome.hotstuff.consensus.basic.{QuorumCertificate, Phase} +import io.iohk.metronome.hotstuff.service.storage.BlockStorageProps +import io.iohk.metronome.storage.InMemoryKVStore +import org.scalacheck.{Properties, Arbitrary, Gen, Prop}, Arbitrary.arbitrary +import org.scalacheck.Prop.{all, forAll, forAllNoShrink, propBoolean} +import monix.eval.Task +import monix.execution.schedulers.TestScheduler +import scala.util.Random +import scala.concurrent.duration._ + +object BlockSynchronizerProps extends Properties("BlockSynchronizer") { + import BlockStorageProps.{ + TestAgreement, + TestBlock, + TestBlockStorage, + TestKVStore, + Namespace, + genNonEmptyBlockTree + } + + case class Prob(value: Double) { + require(value >= 0 && value <= 1) + } + + // Insert the prefix three into "persistent" storage, + // then start multiple concurrent download processes + // from random federation members pointing at various + // nodes in the subtree. + // + // In the end all synced subtree elements should be + // persisted and the ephemeral storage left empty. + // At no point during the process should the persistent + // storage contain a forest. 
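+  // ("Forest" here means a persisted block, other than one with an empty genesis
+  // parent ID, whose parent is missing from the persistent store.)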
+ case class TestFixture( + ancestorTree: List[TestBlock], + descendantTree: List[TestBlock], + requests: List[(TestAgreement.PKey, QuorumCertificate[TestAgreement])], + federation: Federation[TestAgreement.PKey], + random: Random, + timeoutProb: Prob, + corruptProb: Prob + ) { + val persistentRef = Ref.unsafe[Task, TestKVStore.Store] { + TestKVStore.build(ancestorTree) + } + val ephemeralRef = Ref.unsafe[Task, TestKVStore.Store](Map.empty) + + val persistentStore = InMemoryKVStore[Task, Namespace](persistentRef) + val inMemoryStore = InMemoryKVStore[Task, Namespace](ephemeralRef) + + val blockMap = (ancestorTree ++ descendantTree).map { block => + block.id -> block + }.toMap + + val downloadedRef = Ref.unsafe[Task, Set[TestAgreement.Hash]](Set.empty) + + def getBlock( + from: TestAgreement.PKey, + blockHash: TestAgreement.Hash + ): Task[Option[TestAgreement.Block]] = { + val timeout = 5000 + val delay = random.nextDouble() * 2900 + 100 + val isTimeout = random.nextDouble() < timeoutProb.value + val isCorrupt = random.nextDouble() < corruptProb.value + + if (isTimeout) { + Task.pure(None).delayResult(timeout.millis) + } else { + val block = blockMap(blockHash) + val result = if (isCorrupt) corrupt(block) else block + Task { + downloadedRef.update(_ + blockHash) + }.as(Some(result)).delayResult(delay.millis) + } + } + + implicit val storeRunner = persistentStore + + val synchronizer = new BlockSynchronizer[Task, Namespace, TestAgreement]( + publicKey = federation.publicKeys.head, + federation = federation, + blockStorage = TestBlockStorage, + getBlock = getBlock, + inMemoryStore = inMemoryStore, + semaphore = makeSemapshore() + ) + + private def makeSemapshore() = { + import monix.execution.Scheduler.Implicits.global + Semaphore[Task](1).runSyncUnsafe() + } + + def corrupt(block: TestBlock) = block.copy(id = "corrupt") + def isCorrupt(block: TestBlock) = block.id == "corrupt" + } + object TestFixture { + + implicit val arb: Arbitrary[TestFixture] = Arbitrary(gen()) + + def gen(timeoutProb: Prob = Prob(0.2), corruptProb: Prob = Prob(0.2)) = + for { + ancestorTree <- genNonEmptyBlockTree + leaf = ancestorTree.last + descendantTree <- genNonEmptyBlockTree(parent = leaf) + + federationSize <- Gen.choose(3, 10) + federationKeys = Range(0, federationSize).toVector + federation = Federation(federationKeys)(LeaderSelection.RoundRobin) + .getOrElse(sys.error("Can't create federation.")) + + existingPrepares <- Gen.someOf(ancestorTree) + newPrepares <- Gen.atLeastOne(descendantTree) + + prepares = (existingPrepares ++ newPrepares).toList + proposerKeys <- Gen.listOfN(prepares.size, Gen.oneOf(federationKeys)) + + requests = (prepares zip proposerKeys).zipWithIndex.map { + case ((parent, publicKey), idx) => + publicKey -> QuorumCertificate[TestAgreement]( + phase = Phase.Prepare, + viewNumber = ViewNumber(100L + idx), + blockHash = parent.id, + signature = GroupSignature(()) + ) + } + + random <- arbitrary[Int].map(seed => new Random(seed)) + + } yield TestFixture( + ancestorTree, + descendantTree, + requests, + federation, + random, + timeoutProb, + corruptProb + ) + } + + def simulate(duration: FiniteDuration)(test: Task[Prop]): Prop = { + implicit val scheduler = TestScheduler() + // Schedule the execution, using a Future so we can check the value. + val testFuture = test.runToFuture + // Simulate a time. + scheduler.tick(duration) + // Get the completed results. 
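+    // (`value` is None, so `.get` throws, if the task has not finished within the simulated time.)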
+ testFuture.value.get.get + } + + property("sync - persist") = forAll { (fixture: TestFixture) => + val test = for { + fibers <- Task.traverse(fixture.requests) { case (publicKey, qc) => + fixture.synchronizer.sync(publicKey, qc).start + } + _ <- Task.traverse(fibers)(_.join) + downloaded <- fixture.downloadedRef.get + persistent <- fixture.persistentRef.get + ephemeral <- fixture.ephemeralRef.get + } yield { + all( + "ephemeral empty" |: ephemeral.isEmpty, + "persistent contains all" |: fixture.requests.forall { case (_, qc) => + persistent(Namespace.Blocks).contains(qc.blockHash) + }, + "all uncorrupted" |: persistent(Namespace.Blocks).forall { + case (blockHash, block: TestBlock) => + blockHash == block.id && !fixture.isCorrupt(block) + }, + "not download already persisted" |: fixture.ancestorTree.forall { + block => !downloaded(block.id) + } + ) + } + // Simulate a long time, which should be enough for all downloads to finish. + simulate(1.day)(test) + } + + property("sync - no forest") = forAll( + for { + fixture <- TestFixture.gen(timeoutProb = Prob(0)) + duration <- Gen.choose(1, fixture.requests.size).map(_ * 500.millis) + } yield (fixture, duration) + ) { case (fixture: TestFixture, duration: FiniteDuration) => + implicit val scheduler = TestScheduler() + + // Schedule the downloads in the background. + Task + .traverse(fixture.requests) { case (publicKey, qc) => + fixture.synchronizer.sync(publicKey, qc).startAndForget + } + .runAsyncAndForget + + // Simulate a some random time, which may or may not be enough to finish the downloads. + scheduler.tick(duration) + + // Check now that the persistent store has just one tree. + val test = for { + persistent <- fixture.persistentRef.get + } yield { + persistent(Namespace.Blocks).forall { case (_, block: TestBlock) => + // Either the block is the Genesis block with an empty parent ID, + // or it has a parent which has been inserted into the store. + block.parentId.isEmpty || + persistent(Namespace.Blocks).contains(block.parentId) + } + } + + val testFuture = test.runToFuture + + // Just simulate the immediate tasks. + scheduler.tick() + + testFuture.value.get.get + } + + property("getBlockFromQuorumCertificate") = forAllNoShrink( + for { + fixture <- TestFixture + .gen(timeoutProb = Prob(0), corruptProb = Prob(0)) + sources <- Gen.pick( + fixture.federation.quorumSize, + fixture.federation.publicKeys + ) + // The last request is definitely new. + qc = fixture.requests.last._2 + } yield (fixture, sources, qc) + ) { case (fixture, sources, qc) => + val test = for { + block <- fixture.synchronizer + .getBlockFromQuorumCertificate( + sources = NonEmptyVector.fromVectorUnsafe(sources.toVector), + quorumCertificate = qc + ) + .rethrow + persistent <- fixture.persistentRef.get + ephemeral <- fixture.ephemeralRef.get + } yield { + all( + "downloaded" |: block.id == qc.blockHash, + "not in ephemeral" |: ephemeral.isEmpty, + "not in persistent" |: + !persistent(Namespace.Blocks).contains(qc.blockHash) + ) + } + simulate(1.minute)(test) + } + + property("getBlockFromQuorumCertificate - timeout") = forAllNoShrink( + for { + fixture <- TestFixture.gen(timeoutProb = Prob(1)) + request = fixture.requests.last // Use one that isn't persisted yet. 
+ } yield (fixture, request._1, request._2) + ) { case (fixture, source, qc) => + val test = for { + result <- fixture.synchronizer + .getBlockFromQuorumCertificate( + sources = NonEmptyVector.one(source), + quorumCertificate = qc + ) + } yield "fail with the right exception" |: { + result match { + case Left(_: BlockSynchronizer.DownloadFailedException[_]) => + true + case _ => + false + } + } + simulate(1.minute)(test) + } +} diff --git a/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/sync/ViewSynchronizerProps.scala b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/sync/ViewSynchronizerProps.scala new file mode 100644 index 00000000..c6bd409c --- /dev/null +++ b/metronome/hotstuff/service/test/src/io/iohk/metronome/hotstuff/service/sync/ViewSynchronizerProps.scala @@ -0,0 +1,380 @@ +package io.iohk.metronome.hotstuff.service.sync + +import cats.effect.concurrent.Ref +import io.iohk.metronome.hotstuff.consensus.{ + Federation, + LeaderSelection, + ViewNumber +} +import io.iohk.metronome.hotstuff.consensus.basic.{ + ProtocolStateCommands, + QuorumCertificate, + Phase, + VotingPhase, + Signing +} +import io.iohk.metronome.hotstuff.service.Status +import io.iohk.metronome.hotstuff.service.tracing.{SyncTracers, SyncEvent} +import io.iohk.metronome.tracer.Tracer +import monix.eval.Task +import monix.execution.schedulers.TestScheduler +import org.scalacheck.{Arbitrary, Properties, Gen}, Arbitrary.arbitrary +import org.scalacheck.Prop.{forAll, forAllNoShrink, propBoolean, all} +import scala.concurrent.duration._ +import java.util.concurrent.TimeoutException +import cats.data.NonEmptySeq +import scala.util.Random + +object ViewSynchronizerProps extends Properties("ViewSynchronizer") { + import ProtocolStateCommands.{ + TestAgreement, + mockSigning, + mockSigningKey, + genInitialState, + genHash + } + import ViewSynchronizer.FederationStatus + + /** Projected responses in each round from every federation member. */ + type Responses = Vector[Map[TestAgreement.PKey, TestResponse]] + + /** Generate N rounds worth of test responses, during which the synchronizer + * should find the first quorum, unless there's none in any of the rounds, + * in which case it will just keep getting timeouts forever. 
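+    * Byzantine members are simulated by sometimes timing out, returning an
+    * invalid status, or answering honestly.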
+ */ + case class TestFixture( + rounds: Int, + federation: Federation[TestAgreement.PKey], + responses: Responses + ) { + val responseCounterRef = + Ref.unsafe[Task, Map[TestAgreement.PKey, Int]]( + federation.publicKeys.map(_ -> 0).toMap + ) + + val syncEventsRef = + Ref.unsafe[Task, Vector[SyncEvent[TestAgreement]]](Vector.empty) + + private val syncEventTracer = + Tracer.instance[Task, SyncEvent[TestAgreement]] { event => + syncEventsRef.update(_ :+ event) + } + + implicit val syncTracers: SyncTracers[Task, TestAgreement] = + SyncTracers(syncEventTracer) + + def getStatus( + publicKey: TestAgreement.PKey + ): Task[Option[Status[TestAgreement]]] = + for { + round <- responseCounterRef.modify { responseCounter => + val count = responseCounter(publicKey) + responseCounter.updated(publicKey, count + 1) -> count + } + result = + if (round >= responses.size) None + else + responses(round)(publicKey) match { + case TestResponse.Timeout => None + case TestResponse.InvalidStatus(status, _) => Some(status) + case TestResponse.ValidStatus(status) => Some(status) + } + } yield result + } + + object TestFixture { + implicit val leaderSelection = LeaderSelection.RoundRobin + + implicit val arb: Arbitrary[TestFixture] = Arbitrary { + for { + state <- genInitialState + federation = Federation(state.federation, state.f).getOrElse( + sys.error("Invalid federation.") + ) + byzantineCount <- Gen.choose(0, state.f) + byzantines = federation.publicKeys.take(byzantineCount).toSet + rounds <- Gen.posNum[Int] + genesisQC = state.newViewsHighQC + responses <- genResponses(rounds, federation, byzantines, genesisQC) + } yield TestFixture( + rounds, + federation, + responses + ) + } + } + + sealed trait TestResponse + object TestResponse { + case object Timeout extends TestResponse + case class ValidStatus(status: Status[TestAgreement]) extends TestResponse + case class InvalidStatus(status: Status[TestAgreement], reason: String) + extends TestResponse + } + + /** Generate a series of hypothetical responses projected from an idealized consensus process. */ + def genResponses( + rounds: Int, + federation: Federation[TestAgreement.PKey], + byzantines: Set[TestAgreement.PKey], + genesisQC: QuorumCertificate[TestAgreement] + ): Gen[Responses] = { + + def genQC( + viewNumber: ViewNumber, + phase: VotingPhase, + blockHash: TestAgreement.Hash + ) = + for { + quorumKeys <- Gen + .pick(federation.quorumSize, federation.publicKeys) + .map(_.toVector) + partialSigs = quorumKeys.map { publicKey => + val signingKey = mockSigningKey(publicKey) + Signing[TestAgreement].sign( + signingKey, + phase, + viewNumber, + blockHash + ) + } + groupSig = mockSigning.combine(partialSigs) + } yield QuorumCertificate[TestAgreement]( + phase, + viewNumber, + blockHash, + groupSig + ) + + /** Extend a Q.C. by building a new block on top of it. */ + def genPrepareQC(qc: QuorumCertificate[TestAgreement]) = + genHash.flatMap { blockHash => + genQC(qc.viewNumber.next, Phase.Prepare, blockHash) + } + + /** Extend a Q.C. by committing the block in it. 
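+      * (It reuses the view number and block hash of the prepare Q.C. being committed.)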
*/ + def genCommitQC(qc: QuorumCertificate[TestAgreement]) = + genQC(qc.viewNumber, Phase.Commit, qc.blockHash) + + def genInvalid(status: Status[TestAgreement]) = { + def delay(invalid: => (Status[TestAgreement], String)) = + Gen.delay(Gen.const(invalid)) + Gen.oneOf( + delay( + status.copy(viewNumber = + status.prepareQC.viewNumber.prev + ) -> "view number less than prepare" + ), + delay( + status.copy(prepareQC = + status.commitQC + ) -> "commit instead of prepare" + ), + delay( + status.copy(commitQC = + status.prepareQC + ) -> "prepare instead of commit" + ), + delay( + status.copy(commitQC = + status.commitQC.copy[TestAgreement](signature = + status.commitQC.signature + .copy(sig = status.commitQC.signature.sig.map(_ * 2)) + ) + ) -> "wrong commit signature" + ).filter(_._1.commitQC.viewNumber > 0) + ) + } + + def loop( + round: Int, + prepareQC: QuorumCertificate[TestAgreement], + commitQC: QuorumCertificate[TestAgreement], + accum: Responses + ): Gen[Responses] = + if (round == rounds) Gen.const(accum) + else { + val keepCommit = Gen.const(commitQC) + + def maybeCommit(qc: QuorumCertificate[TestAgreement]) = + if (qc.blockHash != commitQC.blockHash) genCommitQC(qc) + else keepCommit + + val genRound = for { + nextPrepareQC <- Gen.oneOf( + Gen.const(prepareQC), + genPrepareQC(prepareQC) + ) + nextCommitQC <- Gen.oneOf( + keepCommit, + maybeCommit(prepareQC), + maybeCommit(nextPrepareQC) + ) + status = Status(ViewNumber(round + 1), nextPrepareQC, nextCommitQC) + responses <- Gen.sequence[Vector[TestResponse], TestResponse] { + federation.publicKeys.map { publicKey => + if (byzantines.contains(publicKey)) { + Gen.frequency( + 3 -> Gen.const(TestResponse.Timeout), + 2 -> Gen.const(TestResponse.ValidStatus(status)), + 5 -> genInvalid(status).map( + (TestResponse.InvalidStatus.apply _).tupled + ) + ) + } else { + Gen.frequency( + 1 -> TestResponse.Timeout, + 4 -> TestResponse.ValidStatus(status) + ) + } + } + } + responseMap = (federation.publicKeys zip responses).toMap + } yield (nextPrepareQC, nextCommitQC, responseMap) + + genRound.flatMap { case (prepareQC, commitQC, responseMap) => + loop(round + 1, prepareQC, commitQC, accum :+ responseMap) + } + } + + loop( + 0, + genesisQC, + genesisQC.copy[TestAgreement](phase = Phase.Commit), + Vector.empty + ) + } + + property("sync") = forAll { (fixture: TestFixture) => + implicit val scheduler = TestScheduler() + import fixture.syncTracers + + val retryTimeout = 5.seconds + val syncTimeout = fixture.rounds * retryTimeout * 2 + val synchronizer = new ViewSynchronizer[Task, TestAgreement]( + federation = fixture.federation, + getStatus = fixture.getStatus, + retryTimeout = retryTimeout + ) + + val test = for { + status <- synchronizer.sync.timeout(syncTimeout).attempt + events <- fixture.syncEventsRef.get + + quorumSize = fixture.federation.quorumSize + + indexOfQuorum = fixture.responses.indexWhere { responseMap => + responseMap.values.collect { case TestResponse.ValidStatus(_) => + }.size >= quorumSize + } + hasQuorum = indexOfQuorum >= 0 + + invalidResponseCount = { + val responses = + if (hasQuorum) fixture.responses.take(indexOfQuorum + 1) + else fixture.responses + responses + .flatMap(_.values) + .collect { case _: TestResponse.InvalidStatus => + } + .size + } + + invalidEventCount = { + events.collect { case _: SyncEvent.InvalidStatus[_] => + }.size + } + + pollSizes = events.collect { case SyncEvent.StatusPoll(statuses) => + statuses.size + } + + responseCounter <- fixture.responseCounterRef.get + } yield { + val statusProps = status 
match { + case Right(FederationStatus(_, sources)) => + "status" |: all( + "quorum" |: hasQuorum, + "reports polls each round" |: + pollSizes.size == indexOfQuorum + 1, + "stop at the first quorum" |: + pollSizes.last >= quorumSize && + pollSizes.init.forall(_ < quorumSize), + "reports all invalid" |: + invalidEventCount == invalidResponseCount, + "returns sources" |: sources.toVector.size >= quorumSize + ) + + case Left(_: TimeoutException) => + "timeout" |: all( + "no quorum" |: !hasQuorum, + "empty polls" |: pollSizes.forall(_ < quorumSize), + "keeps polling" |: pollSizes.size >= fixture.rounds, + "reports all invalid" |: invalidEventCount == invalidResponseCount + ) + + case Left(ex) => + ex.getMessage |: false + } + + all( + statusProps, + "poll everyone in each round" |: + responseCounter.values.toList.distinct.size <= 2 // Some members can get an extra query, down to timing. + ) + } + + val testFuture = test.runToFuture + + scheduler.tick(syncTimeout) + + testFuture.value.get.get + } + + property("median") = forAllNoShrink( + for { + m <- arbitrary[Int].map(_.toLong) + l <- Gen.posNum[Int] + h <- Gen.oneOf(l, l - 1) + ls <- Gen.listOfN(l, Gen.posNum[Int].map(m - _)) + hs <- Gen.listOfN(h, Gen.posNum[Int].map(m + _)) + rnd <- arbitrary[Int].map(new Random(_)) + } yield (m, rnd.shuffle(ls ++ Seq(m) ++ hs)) + ) { case (m, xs) => + m == ViewSynchronizer.median(NonEmptySeq.fromSeqUnsafe(xs)) + } + + property("aggregateStatus") = forAllNoShrink( + for { + fixture <- arbitrary[TestFixture] + statuses = fixture.responses.flatMap(_.values).collect { + case TestResponse.ValidStatus(status) => status + } + if (statuses.nonEmpty) + rnd <- arbitrary[Int].map(new Random(_)) + } yield NonEmptySeq.fromSeqUnsafe(rnd.shuffle(statuses)) + ) { statuses => + val status = + ViewSynchronizer.aggregateStatus(statuses) + + val medianViewNumber = ViewSynchronizer.median(statuses.map(_.viewNumber)) + + val maxViewNumber = + statuses.map(_.viewNumber).toSeq.max + + val maxPrepareQC = + statuses.find(_.viewNumber == maxViewNumber).get.prepareQC + + val maxCommitQC = + statuses.find(_.viewNumber == maxViewNumber).get.commitQC + + all( + "viewNumber" |: + status.viewNumber == + (if (maxPrepareQC.viewNumber > medianViewNumber) maxPrepareQC.viewNumber + else medianViewNumber), + "prepareQC" |: status.prepareQC == maxPrepareQC, + s"commitQC ${status.commitQC} vs ${maxCommitQC}" |: status.commitQC == maxCommitQC + ) + } +} diff --git a/metronome/logging/src/io/iohk/metronome/logging/HybridLog.scala b/metronome/logging/src/io/iohk/metronome/logging/HybridLog.scala new file mode 100644 index 00000000..5ec61ecb --- /dev/null +++ b/metronome/logging/src/io/iohk/metronome/logging/HybridLog.scala @@ -0,0 +1,36 @@ +package io.iohk.metronome.logging + +import io.circe.JsonObject +import java.time.Instant +import scala.reflect.ClassTag + +/** Type class to transform instances of `T` to `HybridLogObject`. */ +trait HybridLog[T] { + def apply(value: T): HybridLogObject +} + +object HybridLog { + def apply[T](implicit ev: HybridLog[T]): HybridLog[T] = ev + + /** Create an instance of `HybridLog` for a type `T` by passing + * functions to transform instances of `T` to message and JSON. 
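+    *
+    * For example, with a hypothetical event type (illustration only, assuming
+    * `io.circe.syntax._` is imported for `asJson`):
+    *
+    * ```
+    * case class BlockExecuted(height: Long)
+    *
+    * implicit val blockExecutedLog: HybridLog[BlockExecuted] =
+    *   HybridLog.instance[BlockExecuted](
+    *     level = _ => HybridLogObject.Level.Info,
+    *     message = _ => "Block executed",
+    *     event = e => JsonObject("height" -> e.height.asJson)
+    *   )
+    * ```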
+ */ + def instance[T: ClassTag]( + level: T => HybridLogObject.Level, + message: T => String, + event: T => JsonObject + ): HybridLog[T] = + new HybridLog[T] { + val source = implicitly[ClassTag[T]].runtimeClass.getName + + override def apply(value: T): HybridLogObject = { + HybridLogObject( + level = level(value), + timestamp = Instant.now(), + source = source, + message = message(value), + event = event(value) + ) + } + } +} diff --git a/metronome/logging/src/io/iohk/metronome/logging/HybridLogObject.scala b/metronome/logging/src/io/iohk/metronome/logging/HybridLogObject.scala new file mode 100644 index 00000000..aacb42ed --- /dev/null +++ b/metronome/logging/src/io/iohk/metronome/logging/HybridLogObject.scala @@ -0,0 +1,38 @@ +package io.iohk.metronome.logging + +import io.circe.JsonObject +import io.circe.syntax._ +import java.time.Instant +import cats.Show + +/** A hybrid log has a human readable message, which is intended to be static, + * and some key-value paramters that vary by events. + * + * See https://medium.com/unomaly/logging-wisdom-how-to-log-5a19145e35ec + */ +case class HybridLogObject( + timestamp: Instant, + source: String, + level: HybridLogObject.Level, + // Something captured about what emitted this event. + // Human readable message, which typically shouldn't + // change between events emitted at the same place. + message: String, + // Key-Value pairs that capture arbitrary data. + event: JsonObject +) +object HybridLogObject { + sealed trait Level + object Level { + case object Error extends Level + case object Warn extends Level + case object Info extends Level + case object Debug extends Level + case object Trace extends Level + } + + implicit val show: Show[HybridLogObject] = Show.show { + case HybridLogObject(t, s, l, m, e) => + s"$t ${l.toString.toUpperCase.padTo(5, ' ')} - $s: $m ${e.asJson.noSpaces}" + } +} diff --git a/metronome/logging/src/io/iohk/metronome/logging/InMemoryLogTracer.scala b/metronome/logging/src/io/iohk/metronome/logging/InMemoryLogTracer.scala new file mode 100644 index 00000000..d67dd7f2 --- /dev/null +++ b/metronome/logging/src/io/iohk/metronome/logging/InMemoryLogTracer.scala @@ -0,0 +1,50 @@ +package io.iohk.metronome.logging + +import cats.implicits._ +import cats.effect.Sync +import cats.effect.concurrent.Ref +import io.iohk.metronome.tracer.Tracer + +/** Collect logs in memory, so we can inspect them in tests. 
*/ +object InMemoryLogTracer { + + class HybridLogTracer[F[_]: Sync]( + logRef: Ref[F, Vector[HybridLogObject]] + ) extends Tracer[F, HybridLogObject] { + + override def apply(a: => HybridLogObject): F[Unit] = + logRef.update(_ :+ a) + + def getLogs: F[Seq[HybridLogObject]] = + logRef.get.map(_.toSeq) + + def getLevel(l: HybridLogObject.Level) = + getLogs.map(_.filter(_.level == l)) + + def getErrors = getLevel(HybridLogObject.Level.Error) + def getWarns = getLevel(HybridLogObject.Level.Warn) + def getInfos = getLevel(HybridLogObject.Level.Info) + def getDebugs = getLevel(HybridLogObject.Level.Debug) + def getTraces = getLevel(HybridLogObject.Level.Trace) + } + + /** For example: + * + * ``` + * val logTracer = InMemoryLogTracer.hybrid[Task] + * val networkEventTracer = logTracer.contramap(implicitly[HybridLog[NetworkEvent]].apply _) + * val consensusEventTracer = logTracer.contramap(implicitly[HybridLog[ConsensusEvent]].apply _) + * + * val test = for { + * msg <- network.nextMessage + * _ <- consensus.handleMessage(msg) + * warns <- logTracer.getWarns + * } yield { + * warns shouldBe empty + * } + * + * ``` + */ + def hybrid[F[_]: Sync]: Tracer[F, HybridLogObject] = + new HybridLogTracer[F](Ref.unsafe[F, Vector[HybridLogObject]](Vector.empty)) +} diff --git a/metronome/logging/src/io/iohk/metronome/logging/LogTracer.scala b/metronome/logging/src/io/iohk/metronome/logging/LogTracer.scala new file mode 100644 index 00000000..aa1ab995 --- /dev/null +++ b/metronome/logging/src/io/iohk/metronome/logging/LogTracer.scala @@ -0,0 +1,38 @@ +package io.iohk.metronome.logging + +import cats.syntax.contravariant._ +import cats.effect.Sync +import io.circe.syntax._ +import io.iohk.metronome.tracer.Tracer +import org.slf4j.LoggerFactory + +/** Forward traces to SLF4J logs. */ +object LogTracer { + + /** Create a logger for `HybridLogObject` that delegates to SLF4J. */ + def hybrid[F[_]: Sync]: Tracer[F, HybridLogObject] = + new Tracer[F, HybridLogObject] { + override def apply(log: => HybridLogObject): F[Unit] = Sync[F].delay { + val logger = LoggerFactory.getLogger(log.source) + + def message = s"${log.message} ${log.event.asJson.noSpaces}" + + log.level match { + case HybridLogObject.Level.Error => + if (logger.isErrorEnabled) logger.error(message) + case HybridLogObject.Level.Warn => + if (logger.isWarnEnabled) logger.warn(message) + case HybridLogObject.Level.Info => + if (logger.isInfoEnabled) logger.info(message) + case HybridLogObject.Level.Debug => + if (logger.isDebugEnabled) logger.debug(message) + case HybridLogObject.Level.Trace => + if (logger.isTraceEnabled) logger.trace(message) + } + } + } + + /** Create a logger for a type that can be transformed to a `HybridLogObject`. 
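+    *
+    * For example (a sketch, assuming a `HybridLog[ConsensusEvent]` instance is in
+    * scope and `ConsensusEvent` is used purely for illustration):
+    *
+    * ```
+    * implicit val consensusEventTracer: Tracer[Task, ConsensusEvent] =
+    *   LogTracer.hybrid[Task, ConsensusEvent]
+    * ```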
*/ + def hybrid[F[_]: Sync, T: HybridLog]: Tracer[F, T] = + hybrid[F].contramap(implicitly[HybridLog[T]].apply _) +} diff --git a/metronome/networking/src/io/iohk/metronome/networking/ConnectionHandler.scala b/metronome/networking/src/io/iohk/metronome/networking/ConnectionHandler.scala new file mode 100644 index 00000000..adaa9c06 --- /dev/null +++ b/metronome/networking/src/io/iohk/metronome/networking/ConnectionHandler.scala @@ -0,0 +1,582 @@ +package io.iohk.metronome.networking + +import cats.effect.concurrent.Deferred +import cats.effect.implicits._ +import cats.effect.{Concurrent, ContextShift, Resource, Sync} +import cats.implicits._ +import io.iohk.metronome.networking.ConnectionHandler.HandledConnection._ +import io.iohk.metronome.networking.ConnectionHandler.{ + ConnectionAlreadyClosedException, + ConnectionWithConflictFlag, + FinishedConnection, + HandledConnection, + MessageReceived +} +import io.iohk.metronome.networking.EncryptedConnectionProvider.{ + ConnectionAlreadyClosed, + ConnectionError +} +import monix.catnap.ConcurrentQueue +import monix.execution.atomic.AtomicInt +import monix.tail.Iterant + +import java.net.InetSocketAddress +import scala.util.control.NoStackTrace +import scala.concurrent.duration._ +import java.time.Instant + +class ConnectionHandler[F[_]: Concurrent, K, M]( + connectionQueue: ConcurrentQueue[F, ConnectionWithConflictFlag[F, K, M]], + connectionsRegister: ConnectionsRegister[F, K, M], + messageQueue: ConcurrentQueue[F, MessageReceived[K, M]], + cancelToken: Deferred[F, Unit], + connectionFinishCallback: FinishedConnection[K] => F[Unit], + oppositeConnectionOverlap: FiniteDuration +)(implicit tracers: NetworkTracers[F, K, M]) { + + private val numberOfRunningConnections = AtomicInt(0) + + private def incrementRunningConnections: F[Unit] = { + Concurrent[F].delay(numberOfRunningConnections.increment()) + } + + private def decrementRunningConnections: F[Unit] = { + Concurrent[F].delay(numberOfRunningConnections.decrement()) + } + + private def closeAndDeregisterConnection( + handledConnection: HandledConnection[F, K, M] + ): F[Unit] = { + val close = for { + _ <- decrementRunningConnections + _ <- connectionsRegister.deregisterConnection(handledConnection) + _ <- handledConnection.close + } yield () + + close.guarantee { + tracers.deregistered(handledConnection) + } + } + + private def register( + possibleNewConnection: HandledConnection[F, K, M] + ): F[Unit] = { + connectionsRegister.registerIfAbsent(possibleNewConnection).flatMap { + maybeConflicting => + // in case of conflict we deal with it in the background + connectionQueue.offer( + (possibleNewConnection, maybeConflicting.isDefined) + ) + } + } + + /** Registers incoming connections and start handling incoming messages in background, in case connection is already handled + * it closes it + * + * @param serverAddress, server address of incoming connection which should already be known + * @param encryptedConnection, established connection + */ + def registerIncoming( + serverAddress: InetSocketAddress, + encryptedConnection: EncryptedConnection[F, K, M] + ): F[Unit] = { + HandledConnection + .incoming(cancelToken, serverAddress, encryptedConnection) + .flatMap(connection => register(connection)) + + } + + /** Registers out connections and start handling incoming messages in background, in case connection is already handled + * it closes it + * + * @param encryptedConnection, established connection + */ + def registerOutgoing( + encryptedConnection: EncryptedConnection[F, K, M] + ): F[Unit] = 
{ + HandledConnection + .outgoing(cancelToken, encryptedConnection) + .flatMap(connection => register(connection)) + } + + /** Checks if handler already handles connection o peer with provided key + * + * @param connectionKey key of remote peer + */ + def isNewConnection(connectionKey: K): F[Boolean] = { + connectionsRegister.isNewConnection(connectionKey) + } + + /** Retrieves set of keys of all connected and handled peers + */ + def getAllActiveConnections: F[Set[K]] = + connectionsRegister.getAllRegisteredConnections.map { connections => + connections.map(_.key) + } + + /** Number of connections actively red in background + */ + def numberOfActiveConnections: F[Int] = { + Concurrent[F].delay(numberOfRunningConnections.get()) + } + + /** Stream of all messages received from all remote peers + */ + def incomingMessages: Iterant[F, MessageReceived[K, M]] = + Iterant.repeatEvalF(messageQueue.poll) + + /** Retrieves handled connection if one exists + * + * @param key, key of remote peer + */ + def getConnection(key: K): F[Option[HandledConnection[F, K, M]]] = + connectionsRegister.getConnection(key) + + /** Send message to remote peer if its connected + * + * @param recipient, key of the remote peer + * @param message message to send + */ + def sendMessage( + recipient: K, + message: M + ): F[Either[ConnectionAlreadyClosedException[K], Unit]] = { + getConnection(recipient).flatMap { + case Some(connection) => + connection + .sendMessage(message) + .attemptNarrow[ConnectionAlreadyClosed] + .flatMap { + case Left(_) => + // Closing the connection will cause it to be re-queued for reconnection. + tracers.sendError(connection) >> + connection.closeAlreadyClosed.as( + Left(ConnectionAlreadyClosedException(recipient)) + ) + + case Right(_) => + tracers.sent((connection, message)).as(Right(())) + } + case None => + Concurrent[F].pure(Left(ConnectionAlreadyClosedException(recipient))) + } + } + + private def handleConflict( + newConnectionWithPossibleConflict: ConnectionWithConflictFlag[F, K, M] + ): F[Option[HandledConnection[F, K, M]]] = { + val (newConnection, conflictHappened) = + newConnectionWithPossibleConflict + + if (conflictHappened) { + connectionsRegister.registerIfAbsent(newConnection).flatMap { + case Some(oldConnection) => + val replace = shouldReplaceConnection( + newConnection = newConnection, + oldConnection = oldConnection + ) + if (replace) { + replaceConnection(newConnection, oldConnection) + } else { + tracers.discarded(newConnection) >> newConnection.close.as(none) + } + case None => + // in the meantime between detection of conflict, and processing it old connection has dropped. Register new one + tracers.registered(newConnection) >> newConnection.some.pure[F] + } + } else { + tracers.registered(newConnection) >> newConnection.some.pure[F] + } + } + + /** Decide whether a new connection to/from a peer should replace an old connection from/to the same peer in case of a conflict. */ + private def shouldReplaceConnection( + newConnection: HandledConnection[F, K, M], + oldConnection: HandledConnection[F, K, M] + ): Boolean = { + if (oldConnection.age() < oppositeConnectionOverlap) { + // The old connection has just been created recently, yet we have a new connection already. + // Most likely the two nodes opened connections to each other around the same time, and if + // we close one of the connections connection based on direction, the node opposite will + // likely be doing the same to the _other_ connection, symmetrically. 
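+      // If that happened we could end up tearing down both connections and having to redial.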
+ // Instead, let's try to establish some ordering between the two, so the same connection + // is chosen as the victim on both sides. + val (newPort, oldPort) = ( + newConnection.ephemeralAddress.getPort, + oldConnection.ephemeralAddress.getPort + ) + newPort < oldPort || newPort == oldPort && + newConnection.ephemeralAddress.getHostName < oldConnection.ephemeralAddress.getHostName + } else { + newConnection.connectionDirection match { + case HandledConnection.IncomingConnection => + // Even though we have connection to this peer, they are calling us. One of the reason may be + // that they failed and we did not notice. Lets try to replace old connection with new one. + true + + case HandledConnection.OutgoingConnection => + // For some reason we were calling while we already have connection, most probably we have + // received incoming connection during call. Close this new connection, and keep the old one. + false + } + } + } + + /** Safely replaces old connection from remote peer with new connection with same remote peer. + * + * 1. The callback for old connection will not be called. As from the perspective of outside world connection is never + * finished + * 2. From the point of view of outside world connection never leaves connection registry i.e during replacing all call to + * registerOutgoing or registerIncoming will report conflicts to be handled + */ + private def replaceConnection( + newConnection: HandledConnection[F, K, M], + oldConnection: HandledConnection[F, K, M] + ): F[Option[HandledConnection[F, K, M]]] = { + for { + result <- oldConnection.requestReplace(newConnection) + maybeNew <- result match { + case ConnectionHandler.ReplaceFinished => + // Replace succeeded, old connection should already be closed and discarded, pass the new one forward + tracers.registered(newConnection) >> + newConnection.some.pure[F] + case ConnectionHandler.ConnectionAlreadyDisconnected => + // during or just before replace, old connection disconnected for some other reason, + // the reconnect call back will be fired either way so close the new connection + tracers.discarded(newConnection) >> + newConnection.close.as(None: Option[HandledConnection[F, K, M]]) + } + + } yield maybeNew + } + + private def callCallBackWithConnection( + handledConnection: HandledConnection[F, K, M] + ): F[Unit] = { + connectionFinishCallback( + FinishedConnection( + handledConnection.key, + handledConnection.serverAddress + ) + ) + } + + private def handleReplace( + replaceRequest: ReplaceRequested[F, K, M] + ): F[Unit] = { + connectionsRegister.replace(replaceRequest.newConnection).flatMap { + case Some(oldConnection) => + // close connection just in case someone who requested replace forgot it + oldConnection.close + case None => + // this case should not happen, as we handle each connection in separate fiber, and only this fiber can remove + // connection with given key. 
+ ().pure[F] + } >> replaceRequest.signalReplaceSuccess + } + + private def handleConnectionFinish( + connection: HandledConnection[F, K, M] + ): F[Unit] = { + // at this point closeReason will always be filled + connection.getCloseReason.flatMap { + case HandledConnection.RemoteClosed => + closeAndDeregisterConnection( + connection + ) >> callCallBackWithConnection(connection) + case RemoteError(e) => + tracers.receiveError( + (connection, e) + ) >> closeAndDeregisterConnection( + connection + ) >> callCallBackWithConnection(connection) + case HandledConnection.ManagerShutdown => + closeAndDeregisterConnection(connection) + case replaceRequest: ReplaceRequested[F, K, M] => + // override old connection with new one, connection count is not changed, and callback is not called + handleReplace(replaceRequest) + } + } + + /** Connections multiplexer, it receives both incoming and outgoing connections and start reading incoming messages from + * them concurrently, putting them on received messages queue. + * In case of error or stream finish it cleans up all resources. + */ + private def handleConnections: F[Unit] = { + Iterant + .repeatEvalF(connectionQueue.poll) + .mapEval(handleConflict) + .collect { case Some(newConnection) => newConnection } + .mapEval { connection => + incrementRunningConnections >> + Iterant + .repeatEvalF( + connection.incomingMessage + ) + .takeWhile(_.isDefined) + .map(_.get) + .mapEval[Unit] { m => + tracers.received((connection, m)) >> + messageQueue.offer( + MessageReceived(connection.key, m) + ) + } + .guarantee( + handleConnectionFinish(connection) + ) + .completedL + .start + } + .completedL + } + + // for now shutdown of all connections is completed in background + private def shutdown: F[Unit] = cancelToken.complete(()).attempt.void +} + +object ConnectionHandler { + type ConnectionWithConflictFlag[F[_], K, M] = + (HandledConnection[F, K, M], Boolean) + + case class ConnectionAlreadyClosedException[K](key: K) + extends RuntimeException( + s"Connection with node ${key}, has already closed" + ) + with NoStackTrace + + private def getConnectionErrorMessage[K]( + e: ConnectionError, + connectionKey: K + ): String = { + e match { + case EncryptedConnectionProvider.DecodingError => + s"Unexpected decoding error on connection with ${connectionKey}" + case EncryptedConnectionProvider.UnexpectedError(ex) => + s"Unexpected error ${ex.getMessage} on connection with ${connectionKey}" + } + } + + case class UnexpectedConnectionError[K](e: ConnectionError, connectionKey: K) + extends RuntimeException(getConnectionErrorMessage(e, connectionKey)) + + case class MessageReceived[K, M](from: K, message: M) + + sealed abstract class ReplaceResult + case object ReplaceFinished extends ReplaceResult + case object ConnectionAlreadyDisconnected extends ReplaceResult + + /** Connection which is already handled by connection handler i.e it is registered in registry and handler is subscribed + * for incoming messages of that connection + * + * @param key, key of remote node + * @param serverAddress, address of the server of remote node. In case of incoming connection it will be different than + * the underlyingConnection remoteAddress, because we will look up the remote address based on the + * `key` in the cluster configuration. 
+ * @param underlyingConnection, encrypted connection to send and receive messages + */ + class HandledConnection[F[_]: Concurrent, K, M] private ( + val connectionDirection: HandledConnectionDirection, + globalCancelToken: Deferred[F, Unit], + val key: K, + val serverAddress: InetSocketAddress, + underlyingConnection: EncryptedConnection[F, K, M], + closeReason: Deferred[F, HandledConnectionCloseReason] + ) { + private val createdAt = Instant.now() + + def age(): FiniteDuration = + (Instant.now().toEpochMilli() - createdAt.toEpochMilli()).millis + + /** For an incoming connection, this is the remote ephemeral address of the socket + * for an outgoing connection, it is the remote server address. + */ + def remoteAddress: InetSocketAddress = + underlyingConnection.remotePeerInfo._2 + + /** For an incoming connection, this is the local server address; + * for an outgoing connection, it is the local ephemeral address of the socket. + */ + def localAddress: InetSocketAddress = underlyingConnection.localAddress + + /** The client side address of the TCP socket. */ + def ephemeralAddress: InetSocketAddress = + connectionDirection match { + case IncomingConnection => remoteAddress + case OutgoingConnection => localAddress + } + + def sendMessage(m: M): F[Unit] = { + underlyingConnection.sendMessage(m) + } + + def close: F[Unit] = { + underlyingConnection.close + } + + def closeAlreadyClosed: F[Unit] = { + completeWithReason(RemoteClosed) >> underlyingConnection.close + } + + def requestReplace( + newConnection: HandledConnection[F, K, M] + ): F[ReplaceResult] = { + ReplaceRequested.requestReplace(newConnection).flatMap { request => + closeReason.complete(request).attempt.flatMap { + case Left(_) => + (ConnectionAlreadyDisconnected: ReplaceResult).pure[F] + case Right(_) => + underlyingConnection.close >> + request.waitForReplaceToFinish.as(ReplaceFinished: ReplaceResult) + } + } + } + + private def completeWithReason(r: HandledConnectionCloseReason): F[Unit] = + closeReason.complete(r).attempt.void + + def getCloseReason: F[HandledConnectionCloseReason] = closeReason.get + + private def handleIncomingEvent( + incomingEvent: Option[Either[ConnectionError, M]] + ): F[Option[M]] = { + incomingEvent match { + case Some(Right(m)) => m.some.pure[F] + case Some(Left(e)) => completeWithReason(RemoteError(e)).as(None) + case None => completeWithReason(RemoteClosed).as(None) + } + } + + def incomingMessage: F[Option[M]] = { + Concurrent[F] + .race(globalCancelToken.get, underlyingConnection.incomingMessage) + .flatMap { + case Left(_) => completeWithReason(ManagerShutdown).as(None) + case Right(e) => handleIncomingEvent(e) + } + } + } + + object HandledConnection { + sealed abstract class HandledConnectionCloseReason + case object RemoteClosed extends HandledConnectionCloseReason + case class RemoteError(e: ConnectionError) + extends HandledConnectionCloseReason + case object ManagerShutdown extends HandledConnectionCloseReason + class ReplaceRequested[F[_]: Sync, K, M]( + val newConnection: HandledConnection[F, K, M], + replaced: Deferred[F, Unit] + ) extends HandledConnectionCloseReason { + def signalReplaceSuccess: F[Unit] = replaced.complete(()).attempt.void + def waitForReplaceToFinish: F[Unit] = replaced.get + } + + object ReplaceRequested { + def requestReplace[F[_]: Concurrent, K, M]( + newConnection: HandledConnection[F, K, M] + ): F[ReplaceRequested[F, K, M]] = { + for { + signal <- Deferred[F, Unit] + } yield new ReplaceRequested(newConnection, signal) + } + } + + sealed abstract class 
HandledConnectionDirection + case object IncomingConnection extends HandledConnectionDirection + case object OutgoingConnection extends HandledConnectionDirection + + private def buildLifeCycleListener[F[_]: Concurrent] + : F[Deferred[F, HandledConnectionCloseReason]] = { + for { + closeReason <- Deferred[F, HandledConnectionCloseReason] + } yield closeReason + } + + private[ConnectionHandler] def outgoing[F[_]: Concurrent, K, M]( + globalCancelToken: Deferred[F, Unit], + encryptedConnection: EncryptedConnection[F, K, M] + ): F[HandledConnection[F, K, M]] = { + buildLifeCycleListener[F].map { closeReason => + new HandledConnection[F, K, M]( + OutgoingConnection, + globalCancelToken, + encryptedConnection.remotePeerInfo._1, + encryptedConnection.remotePeerInfo._2, + encryptedConnection, + closeReason + ) {} + } + } + + private[ConnectionHandler] def incoming[F[_]: Concurrent, K, M]( + globalCancelToken: Deferred[F, Unit], + serverAddress: InetSocketAddress, + encryptedConnection: EncryptedConnection[F, K, M] + ): F[HandledConnection[F, K, M]] = { + buildLifeCycleListener[F].map { closeReason => + new HandledConnection[F, K, M]( + IncomingConnection, + globalCancelToken, + encryptedConnection.remotePeerInfo._1, + serverAddress, + encryptedConnection, + closeReason + ) {} + } + } + + } + + private def buildHandler[F[_]: Concurrent: ContextShift, K, M]( + connectionFinishCallback: FinishedConnection[K] => F[Unit], + oppositeConnectionOverlap: FiniteDuration + )(implicit + tracers: NetworkTracers[F, K, M] + ): F[ConnectionHandler[F, K, M]] = { + for { + cancelToken <- Deferred[F, Unit] + acquiredConnections <- ConnectionsRegister.empty[F, K, M] + messageQueue <- ConcurrentQueue.unbounded[F, MessageReceived[K, M]]() + connectionQueue <- ConcurrentQueue + .unbounded[F, ConnectionWithConflictFlag[F, K, M]]() + } yield new ConnectionHandler[F, K, M]( + connectionQueue, + acquiredConnections, + messageQueue, + cancelToken, + connectionFinishCallback, + oppositeConnectionOverlap + ) + } + + case class FinishedConnection[K]( + connectionKey: K, + connectionServerAddress: InetSocketAddress + ) + + /** Starts connection handler, and polling form connections + * + * @param connectionFinishCallback, callback to be called when connection is finished and get de-registered + */ + def apply[F[_]: Concurrent: ContextShift, K, M]( + connectionFinishCallback: FinishedConnection[K] => F[Unit], + oppositeConnectionOverlap: FiniteDuration + )(implicit + tracers: NetworkTracers[F, K, M] + ): Resource[F, ConnectionHandler[F, K, M]] = { + Resource + .make( + buildHandler[F, K, M]( + connectionFinishCallback, + oppositeConnectionOverlap + ) + ) { handler => + handler.shutdown + } + .flatMap { handler => + for { + _ <- handler.handleConnections.background + } yield handler + } + } + +} diff --git a/metronome/networking/src/io/iohk/metronome/networking/ConnectionsRegister.scala b/metronome/networking/src/io/iohk/metronome/networking/ConnectionsRegister.scala new file mode 100644 index 00000000..59ad6620 --- /dev/null +++ b/metronome/networking/src/io/iohk/metronome/networking/ConnectionsRegister.scala @@ -0,0 +1,63 @@ +package io.iohk.metronome.networking + +import cats.effect.Concurrent +import cats.effect.concurrent.Ref +import io.iohk.metronome.networking.ConnectionHandler.HandledConnection +import cats.implicits._ + +class ConnectionsRegister[F[_]: Concurrent, K, M]( + registerRef: Ref[F, Map[K, HandledConnection[F, K, M]]] +) { + + def registerIfAbsent( + connection: HandledConnection[F, K, M] + ): 
F[Option[HandledConnection[F, K, M]]] = { + registerRef.modify { register => + val connectionKey = connection.key + + if (register.contains(connectionKey)) { + (register, register.get(connectionKey)) + } else { + (register.updated(connectionKey, connection), None) + } + } + } + + def isNewConnection(connectionKey: K): F[Boolean] = { + registerRef.get.map(register => !register.contains(connectionKey)) + } + + def deregisterConnection( + connection: HandledConnection[F, K, M] + ): F[Unit] = { + registerRef.update(register => register - (connection.key)) + } + + def getAllRegisteredConnections: F[Set[HandledConnection[F, K, M]]] = { + registerRef.get.map(register => register.values.toSet) + } + + def getConnection( + connectionKey: K + ): F[Option[HandledConnection[F, K, M]]] = + registerRef.get.map(register => register.get(connectionKey)) + + def replace( + connection: HandledConnection[F, K, M] + ): F[Option[HandledConnection[F, K, M]]] = { + registerRef.modify { register => + register.updated(connection.key, connection) -> register.get( + connection.key + ) + } + } + +} + +object ConnectionsRegister { + def empty[F[_]: Concurrent, K, M]: F[ConnectionsRegister[F, K, M]] = { + Ref + .of(Map.empty[K, HandledConnection[F, K, M]]) + .map(ref => new ConnectionsRegister[F, K, M](ref)) + } +} diff --git a/metronome/networking/src/io/iohk/metronome/networking/EncryptedConnectionProvider.scala b/metronome/networking/src/io/iohk/metronome/networking/EncryptedConnectionProvider.scala new file mode 100644 index 00000000..9e0ebbf2 --- /dev/null +++ b/metronome/networking/src/io/iohk/metronome/networking/EncryptedConnectionProvider.scala @@ -0,0 +1,38 @@ +package io.iohk.metronome.networking + +import io.iohk.metronome.networking.EncryptedConnectionProvider.{ + ConnectionError, + HandshakeFailed +} + +import java.net.InetSocketAddress + +trait EncryptedConnection[F[_], K, M] { + def localAddress: InetSocketAddress + def remotePeerInfo: (K, InetSocketAddress) + def sendMessage(m: M): F[Unit] + def incomingMessage: F[Option[Either[ConnectionError, M]]] + def close: F[Unit] +} + +trait EncryptedConnectionProvider[F[_], K, M] { + def localPeerInfo: (K, InetSocketAddress) + def connectTo( + k: K, + address: InetSocketAddress + ): F[EncryptedConnection[F, K, M]] + def incomingConnection + : F[Option[Either[HandshakeFailed, EncryptedConnection[F, K, M]]]] +} + +object EncryptedConnectionProvider { + case class HandshakeFailed(ex: Throwable, remoteAddress: InetSocketAddress) + + sealed trait ConnectionError + case object DecodingError extends ConnectionError + case class UnexpectedError(ex: Throwable) extends ConnectionError + + case class ConnectionAlreadyClosed(address: InetSocketAddress) + extends RuntimeException + +} diff --git a/metronome/networking/src/io/iohk/metronome/networking/LocalConnectionManager.scala b/metronome/networking/src/io/iohk/metronome/networking/LocalConnectionManager.scala new file mode 100644 index 00000000..7b117542 --- /dev/null +++ b/metronome/networking/src/io/iohk/metronome/networking/LocalConnectionManager.scala @@ -0,0 +1,57 @@ +package io.iohk.metronome.networking + +import cats.implicits._ +import cats.effect.{Concurrent, Timer, Resource, ContextShift} +import java.net.InetSocketAddress +import monix.eval.{TaskLift, TaskLike} +import monix.tail.Iterant +import scodec.Codec + +trait LocalConnectionManager[F[_], K, M] { + def isConnected: F[Boolean] + def incomingMessages: Iterant[F, M] + def sendMessage( + message: M + ): 
F[Either[ConnectionHandler.ConnectionAlreadyClosedException[K], Unit]] +} + +/** Connect to a single local process and keep the connection alive. */ +object LocalConnectionManager { + + def apply[ + F[_]: Concurrent: TaskLift: TaskLike: Timer: ContextShift, + K: Codec, + M: Codec + ]( + encryptedConnectionsProvider: EncryptedConnectionProvider[F, K, M], + targetKey: K, + targetAddress: InetSocketAddress, + retryConfig: RemoteConnectionManager.RetryConfig + )(implicit + tracers: NetworkTracers[F, K, M] + ): Resource[F, LocalConnectionManager[F, K, M]] = { + for { + remoteConnectionManager <- RemoteConnectionManager[F, K, M]( + encryptedConnectionsProvider, + RemoteConnectionManager.ClusterConfig[K]( + Set(targetKey -> targetAddress) + ), + retryConfig + ) + localConnectionManager = new LocalConnectionManager[F, K, M] { + override def isConnected = + remoteConnectionManager.getAcquiredConnections.map( + _.contains(targetKey) + ) + + override def incomingMessages = + remoteConnectionManager.incomingMessages.map { + case ConnectionHandler.MessageReceived(_, m) => m + } + + override def sendMessage(message: M) = + remoteConnectionManager.sendMessage(targetKey, message) + } + } yield localConnectionManager + } +} diff --git a/metronome/networking/src/io/iohk/metronome/networking/Network.scala b/metronome/networking/src/io/iohk/metronome/networking/Network.scala new file mode 100644 index 00000000..9220ef3f --- /dev/null +++ b/metronome/networking/src/io/iohk/metronome/networking/Network.scala @@ -0,0 +1,85 @@ +package io.iohk.metronome.networking + +import cats.implicits._ +import cats.effect.{Sync, Resource, Concurrent, ContextShift} +import io.iohk.metronome.networking.ConnectionHandler.MessageReceived +import monix.tail.Iterant +import monix.catnap.ConcurrentQueue + +/** Network adapter for specializing messages. */ +trait Network[F[_], K, M] { + + /** Receive incoming messages from the network. */ + def incomingMessages: Iterant[F, MessageReceived[K, M]] + + /** Try sending a message to a federation member, if we are connected. */ + def sendMessage(recipient: K, message: M): F[Unit] +} + +object Network { + + def fromRemoteConnnectionManager[F[_]: Sync, K, M]( + manager: RemoteConnectionManager[F, K, M] + ): Network[F, K, M] = new Network[F, K, M] { + override def incomingMessages = + manager.incomingMessages + + override def sendMessage(recipient: K, message: M) = + // Not returning an error if we are trying to send to someone no longer connected, + // this should be handled transparently, delivery is best-effort. + manager.sendMessage(recipient, message).void + } + + /** Consume messges from a network and dispatch them either left or right, + * based on a splitter function. Combine messages the other way. 
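+    *
+    * As an illustrative sketch only (the `ProtocolMessage`, `AgreementMessage`
+    * and `SyncMessage` types below are hypothetical, not part of this code
+    * base), a combined message type could be routed onto two specialised
+    * networks like this:
+    *
+    * {{{
+    * sealed trait ProtocolMessage
+    * case class AgreementMessage(round: Long)  extends ProtocolMessage
+    * case class SyncMessage(blockHeight: Long) extends ProtocolMessage
+    *
+    * Network
+    *   .splitter[Task, ECPublicKey, ProtocolMessage, AgreementMessage, SyncMessage](
+    *     network
+    *   )(
+    *     split = {
+    *       case m: AgreementMessage => Left(m)
+    *       case m: SyncMessage      => Right(m)
+    *     },
+    *     merge = {
+    *       case Left(m)  => m
+    *       case Right(m) => m
+    *     }
+    *   )
+    *   .use { case (agreementNetwork, syncNetwork) =>
+    *     // each specialised network only ever sees its own message type
+    *     agreementNetwork.incomingMessages.headOptionL
+    *   }
+    * }}}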
+ */ + def splitter[F[_]: Concurrent: ContextShift, K, M, L, R]( + network: Network[F, K, M] + )( + split: M => Either[L, R], + merge: Either[L, R] => M + ): Resource[F, (Network[F, K, L], Network[F, K, R])] = + for { + leftQueue <- makeQueue[F, K, L] + rightQueue <- makeQueue[F, K, R] + + _ <- Concurrent[F].background { + network.incomingMessages.mapEval { + case MessageReceived(from, message) => + split(message) match { + case Left(leftMessage) => + leftQueue.offer(MessageReceived(from, leftMessage)) + case Right(rightMessage) => + rightQueue.offer(MessageReceived(from, rightMessage)) + } + }.completedL + } + + leftNetwork = new SplitNetwork[F, K, L]( + leftQueue.poll, + (r, m) => network.sendMessage(r, merge(Left(m))) + ) + + rightNetwork = new SplitNetwork[F, K, R]( + rightQueue.poll, + (r, m) => network.sendMessage(r, merge(Right(m))) + ) + + } yield (leftNetwork, rightNetwork) + + private def makeQueue[F[_]: Concurrent: ContextShift, K, M] = + Resource.liftF { + ConcurrentQueue.unbounded[F, MessageReceived[K, M]](None) + } + + private class SplitNetwork[F[_]: Sync, K, M]( + poll: F[MessageReceived[K, M]], + send: (K, M) => F[Unit] + ) extends Network[F, K, M] { + override def incomingMessages: Iterant[F, MessageReceived[K, M]] = + Iterant.repeatEvalF(poll) + + def sendMessage(recipient: K, message: M) = + send(recipient, message) + } +} diff --git a/metronome/networking/src/io/iohk/metronome/networking/NetworkEvent.scala b/metronome/networking/src/io/iohk/metronome/networking/NetworkEvent.scala new file mode 100644 index 00000000..9b33f609 --- /dev/null +++ b/metronome/networking/src/io/iohk/metronome/networking/NetworkEvent.scala @@ -0,0 +1,61 @@ +package io.iohk.metronome.networking + +import java.net.InetSocketAddress + +/** Events we want to trace. */ +sealed trait NetworkEvent[K, +M] + +object NetworkEvent { + import ConnectionHandler.HandledConnection.HandledConnectionDirection + + case class Peer[K](key: K, address: InetSocketAddress) + + /** The connection to/from the peer has been added to the register. */ + case class ConnectionRegistered[K]( + peer: Peer[K], + direction: HandledConnectionDirection + ) extends NetworkEvent[K, Nothing] + + /** The connection to/from the peer has been closed and removed from the register. */ + case class ConnectionDeregistered[K]( + peer: Peer[K], + direction: HandledConnectionDirection + ) extends NetworkEvent[K, Nothing] + + /** We had two connections to/from the peer and discarded one of them. */ + case class ConnectionDiscarded[K]( + peer: Peer[K], + direction: HandledConnectionDirection + ) extends NetworkEvent[K, Nothing] + + /** Failed to establish connection to remote peer. */ + case class ConnectionFailed[K]( + peer: Peer[K], + numberOfFailures: Int, + error: Throwable + ) extends NetworkEvent[K, Nothing] + + /** Error reading data from a connection. */ + case class ConnectionReceiveError[K]( + peer: Peer[K], + error: EncryptedConnectionProvider.ConnectionError + ) extends NetworkEvent[K, Nothing] + + /** Error sending data over a connection, already disconnected. */ + case class ConnectionSendError[K]( + peer: Peer[K] + ) extends NetworkEvent[K, Nothing] + + /** Incoming connection from someone outside the federation. */ + case class ConnectionUnknown[K](peer: Peer[K]) + extends NetworkEvent[K, Nothing] + + /** Received incoming message from peer. */ + case class MessageReceived[K, M](peer: Peer[K], message: M) + extends NetworkEvent[K, M] + + /** Sent outgoing message to peer. 
*/ + case class MessageSent[K, M](peer: Peer[K], message: M) + extends NetworkEvent[K, M] + +} diff --git a/metronome/networking/src/io/iohk/metronome/networking/NetworkTracers.scala b/metronome/networking/src/io/iohk/metronome/networking/NetworkTracers.scala new file mode 100644 index 00000000..0e75bb3f --- /dev/null +++ b/metronome/networking/src/io/iohk/metronome/networking/NetworkTracers.scala @@ -0,0 +1,80 @@ +package io.iohk.metronome.networking + +import cats.implicits._ +import io.iohk.metronome.tracer.Tracer + +case class NetworkTracers[F[_], K, M]( + unknown: Tracer[F, EncryptedConnection[F, K, M]], + registered: Tracer[F, ConnectionHandler.HandledConnection[F, K, M]], + deregistered: Tracer[F, ConnectionHandler.HandledConnection[F, K, M]], + discarded: Tracer[F, ConnectionHandler.HandledConnection[F, K, M]], + failed: Tracer[F, RemoteConnectionManager.ConnectionFailure[K]], + receiveError: Tracer[F, NetworkTracers.HandledConnectionError[F, K, M]], + sendError: Tracer[F, ConnectionHandler.HandledConnection[F, K, M]], + received: Tracer[F, NetworkTracers.HandledConnectionMessage[F, K, M]], + sent: Tracer[F, NetworkTracers.HandledConnectionMessage[F, K, M]] +) + +object NetworkTracers { + import NetworkEvent._ + import ConnectionHandler.HandledConnection + + type HandledConnectionError[F[_], K, M] = ( + ConnectionHandler.HandledConnection[F, K, M], + EncryptedConnectionProvider.ConnectionError + ) + type HandledConnectionMessage[F[_], K, M] = ( + ConnectionHandler.HandledConnection[F, K, M], + M + ) + + def apply[F[_], K, M]( + tracer: Tracer[F, NetworkEvent[K, M]] + ): NetworkTracers[F, K, M] = + NetworkTracers[F, K, M]( + unknown = tracer.contramap[EncryptedConnection[F, K, M]] { conn => + ConnectionUnknown((Peer.apply[K] _).tupled(conn.remotePeerInfo)) + }, + registered = tracer.contramap[HandledConnection[F, K, M]] { conn => + ConnectionRegistered( + Peer(conn.key, conn.serverAddress), + conn.connectionDirection + ) + }, + deregistered = tracer.contramap[HandledConnection[F, K, M]] { conn => + ConnectionDeregistered( + Peer(conn.key, conn.serverAddress), + conn.connectionDirection + ) + }, + discarded = tracer.contramap[HandledConnection[F, K, M]] { conn => + ConnectionDiscarded( + Peer(conn.key, conn.serverAddress), + conn.connectionDirection + ) + }, + failed = + tracer.contramap[RemoteConnectionManager.ConnectionFailure[K]] { fail => + ConnectionFailed( + Peer(fail.connectionRequest.key, fail.connectionRequest.address), + fail.connectionRequest.numberOfFailures, + fail.err + ) + }, + receiveError = + tracer.contramap[HandledConnectionError[F, K, M]] { case (conn, err) => + ConnectionReceiveError(Peer(conn.key, conn.serverAddress), err) + }, + sendError = tracer.contramap[HandledConnection[F, K, M]] { conn => + ConnectionSendError(Peer(conn.key, conn.serverAddress)) + }, + received = tracer.contramap[HandledConnectionMessage[F, K, M]] { + case (conn, msg) => + MessageReceived(Peer(conn.key, conn.serverAddress), msg) + }, + sent = tracer.contramap[HandledConnectionMessage[F, K, M]] { + case (conn, msg) => + MessageSent(Peer(conn.key, conn.serverAddress), msg) + } + ) +} diff --git a/metronome/networking/src/io/iohk/metronome/networking/RemoteConnectionManager.scala b/metronome/networking/src/io/iohk/metronome/networking/RemoteConnectionManager.scala new file mode 100644 index 00000000..21973336 --- /dev/null +++ b/metronome/networking/src/io/iohk/metronome/networking/RemoteConnectionManager.scala @@ -0,0 +1,362 @@ +package io.iohk.metronome.networking + +import 
cats.effect.implicits._ +import cats.effect.{Concurrent, ContextShift, Resource, Sync, Timer} +import cats.implicits._ +import io.iohk.metronome.networking.ConnectionHandler.{ + FinishedConnection, + MessageReceived +} +import io.iohk.metronome.networking.RemoteConnectionManager.RetryConfig.RandomJitterConfig +import monix.catnap.ConcurrentQueue +import monix.eval.{TaskLift, TaskLike} +import monix.reactive.Observable +import monix.tail.Iterant +import scodec.Codec + +import java.net.InetSocketAddress +import java.util.concurrent.{ThreadLocalRandom, TimeUnit} +import scala.concurrent.duration.FiniteDuration + +class RemoteConnectionManager[F[_]: Sync, K, M: Codec]( + connectionHandler: ConnectionHandler[F, K, M], + localInfo: (K, InetSocketAddress) +) { + + def getLocalPeerInfo: (K, InetSocketAddress) = localInfo + + def getAcquiredConnections: F[Set[K]] = { + connectionHandler.getAllActiveConnections + } + + def incomingMessages: Iterant[F, MessageReceived[K, M]] = + connectionHandler.incomingMessages + + def sendMessage( + recipient: K, + message: M + ): F[Either[ConnectionHandler.ConnectionAlreadyClosedException[K], Unit]] = { + connectionHandler.sendMessage(recipient, message) + } +} + +object RemoteConnectionManager { + case class ConnectionSuccess[F[_], K, M]( + encryptedConnection: EncryptedConnection[F, K, M] + ) + + case class ConnectionFailure[K]( + connectionRequest: OutGoingConnectionRequest[K], + err: Throwable + ) + + private def connectTo[ + F[_]: Sync, + K: Codec, + M: Codec + ]( + encryptedConnectionProvider: EncryptedConnectionProvider[F, K, M], + connectionRequest: OutGoingConnectionRequest[K] + ): F[Either[ConnectionFailure[K], ConnectionSuccess[F, K, M]]] = { + encryptedConnectionProvider + .connectTo(connectionRequest.key, connectionRequest.address) + .redeemWith( + e => Sync[F].pure(Left(ConnectionFailure(connectionRequest, e))), + connection => Sync[F].pure(Right(ConnectionSuccess(connection))) + ) + } + + case class RetryConfig( + initialDelay: FiniteDuration, + backOffFactor: Long, + maxDelay: FiniteDuration, + randomJitterConfig: RandomJitterConfig, + oppositeConnectionOverlap: FiniteDuration + ) + + object RetryConfig { + sealed abstract case class RandomJitterConfig private ( + fractionOfDelay: Double + ) + + object RandomJitterConfig { + import scala.concurrent.duration._ + + /** Build random jitter config + * @param fractionOfTheDelay, in what range in the computed jitter should lay, it should by in range 0..1 + */ + def buildJitterConfig( + fractionOfTheDelay: Double + ): Option[RandomJitterConfig] = { + if (fractionOfTheDelay >= 0 && fractionOfTheDelay <= 1) { + Some(new RandomJitterConfig(fractionOfTheDelay) {}) + } else { + None + } + } + + /** computes new duration with additional random jitter added. 
Works with millisecond precision i.e if provided duration + * will be less than 1 millisecond then no jitter will be added + * @param config,jitter config + * @param delay, duration to randomize it should positive number otherwise no randomization will happen + */ + def randomizeWithJitter( + config: RandomJitterConfig, + delay: FiniteDuration + ): FiniteDuration = { + val fractionDuration = + (delay.max(0.milliseconds) * config.fractionOfDelay).toMillis + if (fractionDuration == 0) { + delay + } else { + val randomized = ThreadLocalRandom + .current() + .nextLong(-fractionDuration, fractionDuration) + val randomFactor = FiniteDuration(randomized, TimeUnit.MILLISECONDS) + delay + randomFactor + } + } + + /** Default jitter config which will keep random jitter in +/-20% range + */ + val defaultConfig: RandomJitterConfig = buildJitterConfig(0.2).get + } + + import scala.concurrent.duration._ + def default: RetryConfig = { + RetryConfig( + initialDelay = 500.milliseconds, + backOffFactor = 2, + maxDelay = 30.seconds, + randomJitterConfig = RandomJitterConfig.defaultConfig, + oppositeConnectionOverlap = 1.second + ) + } + + } + + private def retryConnection[F[_]: Timer: Concurrent, K]( + config: RetryConfig, + failedConnectionRequest: OutGoingConnectionRequest[K] + ): F[OutGoingConnectionRequest[K]] = { + val updatedFailureCount = + failedConnectionRequest.numberOfFailures + 1 + val exponentialBackoff = + math.pow(config.backOffFactor.toDouble, updatedFailureCount).toLong + + val newDelay = + ((config.initialDelay * exponentialBackoff).min(config.maxDelay)) + + val newDelayWithJitter = RandomJitterConfig.randomizeWithJitter( + config.randomJitterConfig, + newDelay + ) + + Timer[F] + .sleep(newDelayWithJitter) + .as(failedConnectionRequest.copy(numberOfFailures = updatedFailureCount)) + + } + + /** Connections are acquired in linear fashion i.e there can be at most one concurrent call to remote peer. + * In case of failure each connection will be retried infinite number of times with exponential backoff between + * each call. 
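+    *
+    * For example, with the default `RetryConfig` (500ms initial delay, back-off
+    * factor 2, 30s maximum delay), the successive retry delays are roughly
+    * 1s, 2s, 4s, 8s, 16s, 30s, 30s, ... (capped at `maxDelay`), each further
+    * randomised by the +/-20% jitter of `RandomJitterConfig.defaultConfig`.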
+ */ + private def acquireConnections[ + F[_]: Concurrent: TaskLift: TaskLike: Timer, + K: Codec, + M: Codec + ]( + encryptedConnectionProvider: EncryptedConnectionProvider[F, K, M], + connectionsToAcquire: ConcurrentQueue[F, OutGoingConnectionRequest[K]], + connectionsHandler: ConnectionHandler[F, K, M], + retryConfig: RetryConfig + )(implicit tracers: NetworkTracers[F, K, M]): F[Unit] = { + + def connectWithErrors( + connectionToAcquire: OutGoingConnectionRequest[K] + ): F[Either[ConnectionFailure[K], Unit]] = { + connectTo(encryptedConnectionProvider, connectionToAcquire).flatMap { + case Left(err) => + Concurrent[F].pure(Left(err)) + case Right(connection) => + connectionsHandler + .registerOutgoing(connection.encryptedConnection) + .as(Right(())) + } + } + + /** Observable is used here as streaming primitive as it has richer api than Iterant and have mapParallelUnorderedF + * combinator, which makes it possible to have multiple concurrent retry timers, which are cancelled when whole + * outer stream is cancelled + */ + Observable + .repeatEvalF(connectionsToAcquire.poll) + .filterEvalF(request => connectionsHandler.isNewConnection(request.key)) + .mapEvalF(connectWithErrors) + .mapParallelUnorderedF(Integer.MAX_VALUE) { + case Left(failure) => + tracers.failed(failure) >> + retryConnection(retryConfig, failure.connectionRequest).flatMap( + updatedRequest => connectionsToAcquire.offer(updatedRequest) + ) + case Right(_) => + Concurrent[F].pure(()) + } + .completedF + } + + /** Reads incoming connections in linear fashion and check if they are on cluster allowed list. + */ + private def handleServerConnections[F[_]: Concurrent: TaskLift, K, M: Codec]( + pg: EncryptedConnectionProvider[F, K, M], + connectionsHandler: ConnectionHandler[F, K, M], + clusterConfig: ClusterConfig[K] + )(implicit tracers: NetworkTracers[F, K, M]): F[Unit] = { + Iterant + .repeatEvalF(pg.incomingConnection) + .takeWhile(_.isDefined) + .map(_.get) + .collect { case Right(value) => + value + } + .mapEval { encryptedConnection => + clusterConfig.getIncomingConnectionServerInfo( + encryptedConnection.remotePeerInfo._1 + ) match { + case Some(incomingConnectionServerAddress) => + connectionsHandler.registerIncoming( + incomingConnectionServerAddress, + encryptedConnection + ) + + case None => + // unknown connection, just close it + tracers.unknown(encryptedConnection) >> + encryptedConnection.close + } + } + .completedL + } + + class HandledConnectionFinisher[F[_]: Concurrent: Timer, K, M]( + connectionsToAcquire: ConcurrentQueue[F, OutGoingConnectionRequest[K]], + retryConfig: RetryConfig + ) { + def finish(finishedConnection: FinishedConnection[K]): F[Unit] = { + retryConnection( + retryConfig, + OutGoingConnectionRequest.initial( + finishedConnection.connectionKey, + finishedConnection.connectionServerAddress + ) + ).flatMap(req => connectionsToAcquire.offer(req)) + } + } + + case class OutGoingConnectionRequest[K]( + key: K, + address: InetSocketAddress, + numberOfFailures: Int + ) + + object OutGoingConnectionRequest { + def initial[K]( + key: K, + address: InetSocketAddress + ): OutGoingConnectionRequest[K] = { + OutGoingConnectionRequest(key, address, 0) + } + } + + case class ClusterConfig[K]( + clusterNodes: Set[(K, InetSocketAddress)] + ) { + val clusterNodesKeys = clusterNodes.map(_._1) + + val serverAddresses = clusterNodes.toMap + + def isAllowedIncomingConnection(k: K): Boolean = + clusterNodesKeys.contains(k) + + def getIncomingConnectionServerInfo(k: K): Option[InetSocketAddress] = + 
serverAddresses.get(k) + + } + + /** Connection manager for static topology cluster. It starts 3 concurrent backgrounds processes: + * 1. Calling process - tries to connect to remote nodes specified in cluster config. In case of failure, retries with + * exponential backoff. + * 2. Server process - reads incoming connections from server socket. Validates that incoming connections is from known + * remote peer specified in cluster config. + * 3. Message reading process - receives connections from both, Calling and Server processes, and for each connections + * start concurrent process reading messages from those connections. In case of some error on connections, it closes + * connection. In case of discovering that one of outgoing connections failed, it request Calling process to establish + * connection once again. + * + * @param encryptedConnectionsProvider component which makes it possible to receive and acquire encrypted connections + * @param clusterConfig static cluster topology configuration + * @param retryConfig retry configuration for outgoing connections (incoming connections are not retried) + */ + def apply[ + F[_]: Concurrent: TaskLift: TaskLike: Timer: ContextShift, + K: Codec, + M: Codec + ]( + encryptedConnectionsProvider: EncryptedConnectionProvider[F, K, M], + clusterConfig: ClusterConfig[K], + retryConfig: RetryConfig + )(implicit + tracers: NetworkTracers[F, K, M] + ): Resource[F, RemoteConnectionManager[F, K, M]] = { + for { + connectionsToAcquireQueue <- Resource.liftF( + ConcurrentQueue.unbounded[F, OutGoingConnectionRequest[K]]() + ) + _ <- Resource.liftF( + connectionsToAcquireQueue.offerMany( + clusterConfig.clusterNodes.collect { + case toConnect + if toConnect != encryptedConnectionsProvider.localPeerInfo => + OutGoingConnectionRequest.initial(toConnect._1, toConnect._2) + } + ) + ) + + handledConnectionFinisher = new HandledConnectionFinisher[F, K, M]( + connectionsToAcquireQueue, + retryConfig + ) + + connectionsHandler <- ConnectionHandler.apply[F, K, M]( + // when each connection will finished it the callback will be called, + // and connection will be put to connections to acquire queue + handledConnectionFinisher.finish, + // A duration where we consider the possibilty that both nodes opened + // connections against each other at the same time, and they should try + // to determinstically pick the same one to close. After this time, + // we interpret duplicate connections as repairing a failure the other + // side has detected, but we haven't yet. 
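+        // (For reference, RetryConfig.default sets this overlap window to 1 second.)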
+ oppositeConnectionOverlap = retryConfig.oppositeConnectionOverlap + ) + + _ <- acquireConnections( + encryptedConnectionsProvider, + connectionsToAcquireQueue, + connectionsHandler, + retryConfig + ).background + + _ <- handleServerConnections( + encryptedConnectionsProvider, + connectionsHandler, + clusterConfig + ).background + + } yield new RemoteConnectionManager[F, K, M]( + connectionsHandler, + encryptedConnectionsProvider.localPeerInfo + ) + + } +} diff --git a/metronome/networking/src/io/iohk/metronome/networking/ScalanetConnectionProvider.scala b/metronome/networking/src/io/iohk/metronome/networking/ScalanetConnectionProvider.scala new file mode 100644 index 00000000..07a5e5e2 --- /dev/null +++ b/metronome/networking/src/io/iohk/metronome/networking/ScalanetConnectionProvider.scala @@ -0,0 +1,192 @@ +package io.iohk.metronome.networking + +import cats.effect.{Resource, Sync} +import io.iohk.metronome.crypto.ECKeyPair +import io.iohk.metronome.networking.EncryptedConnectionProvider.{ + ConnectionAlreadyClosed, + ConnectionError, + DecodingError, + HandshakeFailed, + UnexpectedError +} +import io.iohk.scalanet.peergroup.PeerGroup.{ + ChannelBrokenException, + ServerEvent +} +import io.iohk.scalanet.peergroup.dynamictls.DynamicTLSPeerGroup.{ + Config, + FramingConfig, + PeerInfo +} +import io.iohk.scalanet.peergroup.dynamictls.{DynamicTLSPeerGroup, Secp256k1} +import io.iohk.scalanet.peergroup.{Channel, InetMultiAddress} +import monix.eval.{Task, TaskLift} +import monix.execution.Scheduler +import scodec.Codec + +import java.net.InetSocketAddress +import java.security.SecureRandom + +object ScalanetConnectionProvider { + private class ScalanetEncryptedConnection[F[_]: TaskLift, K: Codec, M: Codec]( + underlyingChannel: Channel[PeerInfo, M], + underlyingChannelRelease: F[Unit], + channelKey: K + ) extends EncryptedConnection[F, K, M] { + + override def close: F[Unit] = underlyingChannelRelease + + override val localAddress: InetSocketAddress = ( + underlyingChannel.from.address.inetSocketAddress + ) + + override val remotePeerInfo: (K, InetSocketAddress) = ( + channelKey, + underlyingChannel.to.address.inetSocketAddress + ) + + override def sendMessage(m: M): F[Unit] = { + TaskLift[F].apply(underlyingChannel.sendMessage(m).onErrorRecoverWith { + case _: ChannelBrokenException[_] => + Task.raiseError( + ConnectionAlreadyClosed( + underlyingChannel.to.address.inetSocketAddress + ) + ) + }) + } + + override def incomingMessage: F[Option[Either[ConnectionError, M]]] = + TaskLift[F].apply(nextNonIdleMessage) + + private val nextNonIdleMessage: Task[Option[Either[ConnectionError, M]]] = { + underlyingChannel.nextChannelEvent.flatMap { + case Some(event) => + event match { + case Channel.MessageReceived(m) => + Task.pure(Some(Right(m))) + case Channel.UnexpectedError(e) => + Task.pure(Some(Left(UnexpectedError(e)))) + case Channel.DecodingError => + Task.pure(Some(Left(DecodingError))) + case Channel.ChannelIdle(_, _) => + nextNonIdleMessage + } + case None => + Task.pure(None) + } + } + } + + private object ScalanetEncryptedConnection { + def apply[F[_]: TaskLift, K: Codec, M: Codec]( + channel: Channel[PeerInfo, M], + channelRelease: Task[Unit] + ): Task[EncryptedConnection[F, K, M]] = { + + Task + .fromTry(Codec[K].decodeValue(channel.to.id).toTry) + .map { key => + new ScalanetEncryptedConnection[F, K, M]( + channel, + TaskLift[F].apply(channelRelease), + key + ) + } + .onErrorHandleWith { e => + channelRelease.flatMap(_ => Task.raiseError(e)) + } + + } + + } + + // Codec constraint 
for K is necessary as scalanet require peer key to be in BitVector format + def scalanetProvider[F[_]: Sync: TaskLift, K: Codec, M: Codec]( + bindAddress: InetSocketAddress, + nodeKeyPair: ECKeyPair, + secureRandom: SecureRandom, + useNativeTlsImplementation: Boolean, + framingConfig: FramingConfig, + maxIncomingQueueSizePerPeer: Int + )(implicit + sch: Scheduler + ): Resource[F, EncryptedConnectionProvider[F, K, M]] = { + for { + config <- Resource.liftF[F, Config]( + Sync[F].fromTry( + DynamicTLSPeerGroup + .Config( + bindAddress, + Secp256k1, + nodeKeyPair.underlying, + secureRandom, + useNativeTlsImplementation, + framingConfig, + maxIncomingQueueSizePerPeer, + incomingConnectionsThrottling = None, + stalePeerDetectionConfig = None + ) + ) + ) + pg <- DynamicTLSPeerGroup[M](config).mapK(TaskLift.apply) + local <- Resource.pure( + ( + Codec[K].decodeValue(pg.processAddress.id).require, + pg.processAddress.address.inetSocketAddress + ) + ) + + } yield new EncryptedConnectionProvider[F, K, M] { + override def localPeerInfo: (K, InetSocketAddress) = local + + import cats.implicits._ + + /** Connects to remote node, creating new connection with each call + * + * @param k, key of the remote node + * @param address, address of the remote node + */ + override def connectTo( + k: K, + address: InetSocketAddress + ): F[EncryptedConnection[F, K, M]] = { + val encodedKey = Codec[K].encode(k).require + pg.client(PeerInfo(encodedKey, InetMultiAddress(address))) + .mapK[Task, F](TaskLift[F]) + .allocated + .map { case (channel, release) => + new ScalanetEncryptedConnection(channel, release, k) + } + } + + override def incomingConnection + : F[Option[Either[HandshakeFailed, EncryptedConnection[F, K, M]]]] = { + TaskLift[F].apply(pg.nextServerEvent.flatMap { + case Some(ev) => + ev match { + case ServerEvent.ChannelCreated(channel, release) => + ScalanetEncryptedConnection[F, K, M](channel, release).map { + connection => + Some(Right(connection)) + } + + case ServerEvent.HandshakeFailed(failure) => + Task.now( + Some( + Left( + HandshakeFailed( + failure, + failure.to.address.inetSocketAddress + ) + ) + ) + ) + + } + case None => Task.now(None) + }) + } + } + } +} diff --git a/metronome/networking/test/resources/logback.xml b/metronome/networking/test/resources/logback.xml new file mode 100644 index 00000000..d3406af3 --- /dev/null +++ b/metronome/networking/test/resources/logback.xml @@ -0,0 +1,18 @@ + + + + + + %d{HH:mm:ss.SSS} %-5level %logger{36} %msg%n + + + + + + + + + + + + diff --git a/metronome/networking/test/src/io/iohk/metronome/networking/ConnectionHandlerSpec.scala b/metronome/networking/test/src/io/iohk/metronome/networking/ConnectionHandlerSpec.scala new file mode 100644 index 00000000..5c8a7588 --- /dev/null +++ b/metronome/networking/test/src/io/iohk/metronome/networking/ConnectionHandlerSpec.scala @@ -0,0 +1,343 @@ +package io.iohk.metronome.networking + +import cats.effect.Resource +import cats.effect.concurrent.{Deferred, Ref} +import io.iohk.metronome.crypto.ECPublicKey +import io.iohk.metronome.networking.ConnectionHandler.{ + ConnectionAlreadyClosedException, + FinishedConnection +} +import io.iohk.metronome.networking.ConnectionHandlerSpec.{ + buildHandlerResource, + buildNConnections, + _ +} +import io.iohk.metronome.networking.EncryptedConnectionProvider.DecodingError +import io.iohk.metronome.networking.MockEncryptedConnectionProvider.MockEncryptedConnection +import io.iohk.metronome.networking.RemoteConnectionManagerTestUtils._ +import monix.eval.Task +import 
monix.execution.Scheduler +import org.scalatest.flatspec.AsyncFlatSpecLike +import org.scalatest.matchers.should.Matchers +import io.iohk.metronome.tracer.Tracer + +import java.net.InetSocketAddress +import scala.concurrent.duration._ + +class ConnectionHandlerSpec extends AsyncFlatSpecLike with Matchers { + implicit val testScheduler = + Scheduler.fixedPool("ConnectionHandlerSpec", 16) + implicit val timeOut = 5.seconds + + behavior of "ConnectionHandler" + + it should "register new connections" in customTestCaseResourceT( + buildHandlerResource() + ) { handler => + for { + newConnection <- MockEncryptedConnection() + _ <- handler.registerOutgoing(newConnection) + connections <- handler.getAllActiveConnections + } yield { + assert(connections.contains(newConnection.key)) + } + } + + it should "send message to registered connection" in customTestCaseResourceT( + buildHandlerResource() + ) { handler => + for { + newConnection <- MockEncryptedConnection() + _ <- handler.registerOutgoing(newConnection) + connections <- handler.getAllActiveConnections + sendResult <- handler.sendMessage(newConnection.key, MessageA(1)) + } yield { + assert(connections.contains(newConnection.key)) + assert(sendResult.isRight) + } + } + + it should "fail to send message to un-registered connection" in customTestCaseResourceT( + buildHandlerResource() + ) { handler => + for { + newConnection <- MockEncryptedConnection() + connections <- handler.getAllActiveConnections + sendResult <- handler.sendMessage(newConnection.key, MessageA(1)) + } yield { + assert(connections.isEmpty) + assert(sendResult.isLeft) + assert( + sendResult.left.getOrElse(null) == ConnectionAlreadyClosedException( + newConnection.key + ) + ) + } + } + + it should "fail to send message silently failed peer" in customTestCaseResourceT( + buildHandlerResource() + ) { handler => + for { + newConnection <- MockEncryptedConnection() + _ <- newConnection.closeRemoteWithoutInfo + _ <- handler.registerOutgoing(newConnection) + connections <- handler.getAllActiveConnections + sendResult <- handler.sendMessage(newConnection.key, MessageA(1)) + } yield { + assert(connections.contains(newConnection.key)) + assert(sendResult.isLeft) + assert( + sendResult.left.getOrElse(null) == ConnectionAlreadyClosedException( + newConnection.key + ) + ) + } + } + + it should "not register and close duplicated outgoing connection" in customTestCaseResourceT( + buildHandlerResource() + ) { handler => + for { + initialConnection <- MockEncryptedConnection() + duplicatedConnection <- MockEncryptedConnection( + (initialConnection.key, initialConnection.address) + ) + _ <- handler.registerOutgoing(initialConnection) + connections <- handler.getAllActiveConnections + _ <- handler.registerOutgoing(duplicatedConnection) + connectionsAfterDuplication <- handler.getAllActiveConnections + _ <- duplicatedConnection.isClosed.waitFor(closed => closed) + duplicatedClosed <- duplicatedConnection.isClosed + initialClosed <- initialConnection.isClosed + } yield { + assert(connections.contains(initialConnection.key)) + assert(connectionsAfterDuplication.contains(initialConnection.key)) + assert(duplicatedClosed) + assert(!initialClosed) + } + } + + it should "replace incoming connections" in customTestCaseResourceT( + buildHandlerResourceWithCallbackCounter + ) { case (handler, counter) => + for { + initialConnection <- MockEncryptedConnection() + duplicatedConnection <- MockEncryptedConnection( + (initialConnection.key, initialConnection.address) + ) + _ <- 
handler.registerIncoming(fakeLocalAddress, initialConnection) + connections <- handler.getAllActiveConnections + _ <- handler.registerIncoming(fakeLocalAddress, duplicatedConnection) + _ <- initialConnection.isClosed.waitFor(closed => closed) + connectionsAfterDuplication <- handler.getAllActiveConnections + initialClosed <- initialConnection.isClosed + duplicatedClosed <- duplicatedConnection.isClosed + numberOfCalledCallbacks <- counter.get + } yield { + assert(connections.contains(initialConnection.key)) + assert(connectionsAfterDuplication.contains(initialConnection.key)) + assert(initialClosed) + assert(!duplicatedClosed) + assert(numberOfCalledCallbacks == 0) + } + } + + it should "treat last conflicting incoming connection as live one" in customTestCaseResourceT( + buildHandlerResourceWithCallbackCounter + ) { case (handler, counter) => + val numberOfConflictingConnections = 4 + + for { + initialConnection <- MockEncryptedConnection() + duplicatedConnections <- Task.traverse( + (0 until numberOfConflictingConnections).toList + )(_ => + MockEncryptedConnection( + (initialConnection.key, initialConnection.address) + ) + ) + + _ <- handler.registerIncoming(fakeLocalAddress, initialConnection) + connections <- handler.getAllActiveConnections + (closed, last) = ( + duplicatedConnections.dropRight(1), + duplicatedConnections.last + ) + _ <- Task.traverse(duplicatedConnections)(duplicated => + handler.registerIncoming(fakeLocalAddress, duplicated) + ) + allDuplicatesClosed <- Task + .sequence(closed.map(connection => connection.isClosed)) + .map(statusList => statusList.forall(closed => closed)) + .waitFor(allClosed => allClosed) + lastClosed <- last.isClosed + numberOfCalledCallbacks <- counter.get + activeConnectionsAfterConflicts <- handler.getAllActiveConnections + } yield { + assert(connections.contains(initialConnection.key)) + assert(allDuplicatesClosed) + assert(!lastClosed) + assert(numberOfCalledCallbacks == 0) + assert(activeConnectionsAfterConflicts.size == 1) + + } + } + + it should "close all connections in background when released" in customTestCaseT { + val expectedNumberOfConnections = 4 + for { + handlerAndRelease <- buildHandlerResource().allocated + (handler, release) = handlerAndRelease + connections <- buildNConnections(expectedNumberOfConnections) + _ <- Task.traverse(connections)(connection => + handler.registerOutgoing(connection) + ) + maxNumberOfActiveConnections <- handler.numberOfActiveConnections + .waitFor(numOfConnections => + numOfConnections == expectedNumberOfConnections + ) + + _ <- release + connectionsAfterClose <- handler.getAllActiveConnections.waitFor( + connections => connections.isEmpty + ) + } yield { + assert(maxNumberOfActiveConnections == expectedNumberOfConnections) + assert(connectionsAfterClose.isEmpty) + } + } + + it should "call provided callback when connection is closed" in customTestCaseT { + for { + cb <- Deferred.tryable[Task, Unit] + handlerAndRelease <- buildHandlerResource(_ => cb.complete(())).allocated + (handler, release) = handlerAndRelease + newConnection <- MockEncryptedConnection() + _ <- handler.registerOutgoing(newConnection) + numberOfActive <- handler.numberOfActiveConnections.waitFor(_ == 1) + _ <- newConnection.pushRemoteEvent(None) + numberOfActiveAfterDisconnect <- handler.numberOfActiveConnections + .waitFor(_ == 0) + callbackCompleted <- cb.tryGet.waitFor(_.isDefined) + _ <- release + } yield { + assert(numberOfActive == 1) + assert(numberOfActiveAfterDisconnect == 0) + assert(callbackCompleted.isDefined) + } + } 
+ + it should "call provided callback and close connection in case of error" in customTestCaseT { + for { + cb <- Deferred.tryable[Task, Unit] + handlerAndRelease <- buildHandlerResource(_ => cb.complete(())).allocated + (handler, release) = handlerAndRelease + newConnection <- MockEncryptedConnection() + _ <- handler.registerOutgoing(newConnection) + numberOfActive <- handler.numberOfActiveConnections.waitFor(_ == 1) + _ <- newConnection.pushRemoteEvent(Some(Left(DecodingError))) + numberOfActiveAfterError <- handler.numberOfActiveConnections + .waitFor(_ == 0) + callbackCompleted <- cb.tryGet.waitFor(_.isDefined) + _ <- release + } yield { + assert(numberOfActive == 1) + assert(numberOfActiveAfterError == 0) + assert(callbackCompleted.isDefined) + } + } + + it should "try not to call callback in case of closing manager" in customTestCaseT { + for { + cb <- Deferred.tryable[Task, Unit] + handlerAndRelease <- buildHandlerResource(_ => cb.complete(())).allocated + (handler, release) = handlerAndRelease + newConnection <- MockEncryptedConnection() + _ <- handler.registerOutgoing(newConnection) + numberOfActive <- handler.numberOfActiveConnections.waitFor(_ == 1) + _ <- release + numberOfActiveAfterDisconnect <- handler.numberOfActiveConnections + .waitFor(_ == 0) + callbackCompleted <- cb.tryGet.waitFor(_.isDefined).attempt + } yield { + assert(numberOfActive == 1) + assert(numberOfActiveAfterDisconnect == 0) + assert(callbackCompleted.isLeft) + } + } + + it should "multiplex messages from all open channels" in customTestCaseResourceT( + buildHandlerResource() + ) { handler => + val expectedNumberOfConnections = 4 + for { + connections <- buildNConnections(expectedNumberOfConnections) + _ <- Task.traverse(connections)(connection => + handler.registerOutgoing(connection) + ) + maxNumberOfActiveConnections <- handler.numberOfActiveConnections + .waitFor(numOfConnections => + numOfConnections == expectedNumberOfConnections + ) + _ <- Task.traverse(connections) { encConnection => + encConnection.pushRemoteEvent(Some(Right(MessageA(1)))) + } + receivedMessages <- handler.incomingMessages + .take(expectedNumberOfConnections) + .toListL + } yield { + + val senders = connections.map(_.key).toSet + val receivedFrom = receivedMessages.map(_.from).toSet + assert(receivedMessages.size == expectedNumberOfConnections) + assert(maxNumberOfActiveConnections == expectedNumberOfConnections) + assert( + senders.intersect(receivedFrom).size == expectedNumberOfConnections + ) + } + } + +} + +object ConnectionHandlerSpec { + val fakeLocalAddress = new InetSocketAddress("localhost", 9081) + + implicit class TaskOps[A](task: Task[A]) { + def waitFor(condition: A => Boolean)(implicit timeOut: FiniteDuration) = { + task.restartUntil(condition).timeout(timeOut) + } + } + + implicit val tracers: NetworkTracers[Task, ECPublicKey, TestMessage] = + NetworkTracers(Tracer.noOpTracer) + + def buildHandlerResource( + cb: FinishedConnection[ECPublicKey] => Task[Unit] = _ => Task(()) + ): Resource[Task, ConnectionHandler[Task, ECPublicKey, TestMessage]] = { + ConnectionHandler + .apply[Task, ECPublicKey, TestMessage]( + cb, + oppositeConnectionOverlap = Duration.Zero + ) + } + + def buildHandlerResourceWithCallbackCounter: Resource[ + Task, + (ConnectionHandler[Task, ECPublicKey, TestMessage], Ref[Task, Long]) + ] = { + for { + counter <- Resource.liftF(Ref.of[Task, Long](0L)) + handler <- buildHandlerResource(_ => + counter.update(current => current + 1) + ) + } yield (handler, counter) + } + + def buildNConnections(n: 
Int)(implicit + s: Scheduler + ): Task[List[MockEncryptedConnection]] = { + Task.traverse((0 until n).toList)(_ => MockEncryptedConnection()) + } + +} diff --git a/metronome/networking/test/src/io/iohk/metronome/networking/MockEncryptedConnectionProvider.scala b/metronome/networking/test/src/io/iohk/metronome/networking/MockEncryptedConnectionProvider.scala new file mode 100644 index 00000000..d74f9975 --- /dev/null +++ b/metronome/networking/test/src/io/iohk/metronome/networking/MockEncryptedConnectionProvider.scala @@ -0,0 +1,273 @@ +package io.iohk.metronome.networking + +import cats.effect.concurrent.{Deferred, Ref, TryableDeferred} +import cats.implicits.toFlatMapOps +import io.iohk.metronome.crypto.ECPublicKey +import io.iohk.metronome.networking.EncryptedConnectionProvider.ConnectionAlreadyClosed +import io.iohk.metronome.networking.MockEncryptedConnectionProvider._ +import io.iohk.metronome.networking.RemoteConnectionManagerTestUtils.{ + TestMessage, + getFakeRandomKey +} +import io.iohk.metronome.networking.RemoteConnectionManagerWithMockProviderSpec.fakeLocalAddress +import monix.catnap.ConcurrentQueue +import monix.eval.Task + +import java.net.InetSocketAddress + +class MockEncryptedConnectionProvider( + private val incomingConnections: ConcurrentQueue[Task, IncomingServerEvent], + private val onlineConnections: Ref[ + Task, + Map[ECPublicKey, MockEncryptedConnection] + ], + private val connectionStatistics: ConnectionStatisticsHolder, + val localPeerInfo: (ECPublicKey, InetSocketAddress) = + (getFakeRandomKey(), fakeLocalAddress) +) extends EncryptedConnectionProvider[Task, ECPublicKey, TestMessage] { + + private def connect(k: ECPublicKey) = { + onlineConnections.get.flatMap { state => + state.get(k) match { + case Some(value) => Task.now(value) + case None => + Task.raiseError(new RuntimeException("Failed connections")) + } + } + } + + override def connectTo( + k: ECPublicKey, + address: InetSocketAddress + ): Task[MockEncryptedConnection] = { + (for { + _ <- connectionStatistics.incrementInFlight(k) + connection <- connect(k) + } yield connection).doOnFinish(_ => connectionStatistics.decrementInFlight) + } + + override def incomingConnection: Task[IncomingServerEvent] = + incomingConnections.poll +} + +object MockEncryptedConnectionProvider { + def apply(): Task[MockEncryptedConnectionProvider] = { + for { + queue <- ConcurrentQueue.unbounded[Task, IncomingServerEvent]() + connections <- Ref.of[Task, Map[ECPublicKey, MockEncryptedConnection]]( + Map.empty + ) + connectionsStatistics <- Ref.of[Task, ConnectionStatistics]( + ConnectionStatistics(0, 0, Map.empty) + ) + } yield new MockEncryptedConnectionProvider( + queue, + connections, + new ConnectionStatisticsHolder(connectionsStatistics) + ) + } + + implicit class MockEncryptedConnectionProviderTestMethodsOps( + provider: MockEncryptedConnectionProvider + ) { + + private def disconnect( + withFailure: Boolean, + chosenPeer: Option[ECPublicKey] = None + ): Task[MockEncryptedConnection] = { + provider.onlineConnections + .modify { current => + chosenPeer.fold { + val peer = current.head + (current - peer._1, peer._2) + } { keyToFail => + val peer = current(keyToFail) + (current - keyToFail, peer) + } + } + .flatTap { connection => + if (withFailure) { + connection.closeRemoteWithoutInfo + } else { + connection.close + } + } + } + + def randomPeerDisconnect(): Task[MockEncryptedConnection] = { + disconnect(withFailure = false) + } + + def specificPeerDisconnect( + key: ECPublicKey + ): Task[MockEncryptedConnection] = { + 
disconnect(withFailure = false, Some(key)) + } + + def failRandomPeer(): Task[MockEncryptedConnection] = { + disconnect(withFailure = true) + } + + def registerOnlinePeer(key: ECPublicKey): Task[MockEncryptedConnection] = { + for { + connection <- MockEncryptedConnection((key, fakeLocalAddress)) + _ <- provider.onlineConnections.update { connections => + connections.updated( + key, + connection + ) + } + } yield connection + } + + def getAllRegisteredPeers: Task[Set[MockEncryptedConnection]] = { + provider.onlineConnections.get.map(connections => + connections.values.toSet + ) + } + + def newIncomingPeer(key: ECPublicKey): Task[MockEncryptedConnection] = { + registerOnlinePeer(key).flatMap { connection => + provider.incomingConnections + .offer(Some(Right(connection))) + .map(_ => connection) + } + } + + def getReceivedMessagesPerPeer + : Task[Set[(ECPublicKey, List[TestMessage])]] = { + provider.onlineConnections.get.flatMap { connections => + Task.traverse(connections.toSet) { case (key, connection) => + connection.getReceivedMessages.map(received => (key, received)) + } + } + } + + def getStatistics: Task[ConnectionStatistics] = + provider.connectionStatistics.stats.get + + } + + case class ConnectionStatistics( + inFlightConnections: Long, + maxInFlightConnections: Long, + connectionCounts: Map[ECPublicKey, Long] + ) + + class ConnectionStatisticsHolder(val stats: Ref[Task, ConnectionStatistics]) { + def incrementInFlight(connectionTo: ECPublicKey): Task[Unit] = { + stats.update { current => + val newInFlight = current.inFlightConnections + 1 + val newMax = + if (newInFlight > current.maxInFlightConnections) newInFlight + else current.maxInFlightConnections + + val newPerConnectionStats = + current.connectionCounts.get(connectionTo) match { + case Some(value) => + current.connectionCounts.updated(connectionTo, value + 1L) + case None => current.connectionCounts.updated(connectionTo, 0L) + } + + ConnectionStatistics(newInFlight, newMax, newPerConnectionStats) + } + } + + def decrementInFlight: Task[Unit] = { + stats.update(current => + current.copy(inFlightConnections = current.inFlightConnections - 1) + ) + } + } + + type IncomingServerEvent = Option[Either[ + EncryptedConnectionProvider.HandshakeFailed, + EncryptedConnection[Task, ECPublicKey, TestMessage] + ]] + + type IncomingConnectionEvent = + Option[Either[EncryptedConnectionProvider.ConnectionError, TestMessage]] + + class MockEncryptedConnection( + private val incomingEvents: ConcurrentQueue[ + Task, + IncomingConnectionEvent + ], + private val closeToken: TryableDeferred[Task, Unit], + private val sentMessages: Ref[Task, List[TestMessage]], + val remotePeerInfo: (ECPublicKey, InetSocketAddress) = + (getFakeRandomKey(), fakeLocalAddress), + val localAddress: InetSocketAddress = fakeLocalAddress + ) extends EncryptedConnection[Task, ECPublicKey, TestMessage] { + + override def close: Task[Unit] = { + Task + .parZip2(incomingEvents.offer(None), closeToken.complete(()).attempt) + .void + } + + override def incomingMessage: Task[IncomingConnectionEvent] = + incomingEvents.poll + + override def sendMessage(m: TestMessage): Task[Unit] = + closeToken.tryGet.flatMap { + case Some(_) => + Task.raiseError(ConnectionAlreadyClosed(remotePeerInfo._2)) + case None => + Task + .race(closeToken.get, sentMessages.update(current => m :: current)) + .flatMap { + case Left(_) => + Task.raiseError(ConnectionAlreadyClosed(remotePeerInfo._2)) + case Right(_) => Task.now(()) + } + } + } + + object MockEncryptedConnection { + def apply( + 
remotePeerInfo: (ECPublicKey, InetSocketAddress) = + (getFakeRandomKey(), fakeLocalAddress) + ): Task[MockEncryptedConnection] = { + for { + incomingEvents <- ConcurrentQueue + .unbounded[Task, IncomingConnectionEvent]() + closeToken <- Deferred.tryable[Task, Unit] + sentMessages <- Ref.of[Task, List[TestMessage]](List.empty[TestMessage]) + } yield new MockEncryptedConnection( + incomingEvents, + closeToken, + sentMessages, + remotePeerInfo + ) + } + + implicit class MockEncryptedConnectionTestMethodsOps( + connection: MockEncryptedConnection + ) { + lazy val key = connection.remotePeerInfo._1 + + lazy val address = connection.remotePeerInfo._2 + + def pushRemoteEvent( + ev: Option[ + Either[EncryptedConnectionProvider.ConnectionError, TestMessage] + ] + ): Task[Unit] = { + connection.incomingEvents.offer(ev) + } + + def getReceivedMessages: Task[List[TestMessage]] = + connection.sentMessages.get + + // it is possible that in some cases remote peer will be closed without generating final None event in incoming events + // queue + def closeRemoteWithoutInfo: Task[Unit] = + connection.closeToken.complete(()) + + def isClosed: Task[Boolean] = + connection.closeToken.tryGet.map(closed => closed.isDefined) + } + } + +} diff --git a/metronome/networking/test/src/io/iohk/metronome/networking/NetworkSpec.scala b/metronome/networking/test/src/io/iohk/metronome/networking/NetworkSpec.scala new file mode 100644 index 00000000..dbf04bdd --- /dev/null +++ b/metronome/networking/test/src/io/iohk/metronome/networking/NetworkSpec.scala @@ -0,0 +1,98 @@ +package io.iohk.metronome.networking + +import cats.effect.Resource +import cats.effect.concurrent.Ref +import monix.eval.Task +import monix.tail.Iterant +import monix.execution.Scheduler.Implicits.global +import org.scalatest.flatspec.AsyncFlatSpec +import org.scalatest.matchers.should.Matchers +import io.iohk.metronome.networking.ConnectionHandler.MessageReceived + +class NetworkSpec extends AsyncFlatSpec with Matchers { + + sealed trait TestMessage + case class TestFoo(foo: String) extends TestMessage + case class TestBar(bar: Int) extends TestMessage + + type TestKey = String + + type TestKeyAndMessage = (TestKey, TestMessage) + type TestMessageReceived = MessageReceived[TestKey, TestMessage] + + class TestNetwork( + outbox: Vector[TestKeyAndMessage], + val inbox: Ref[Task, Vector[ + MessageReceived[TestKey, TestMessage] + ]] + ) extends Network[Task, TestKey, TestMessage] { + + override def incomingMessages: Iterant[Task, TestMessageReceived] = + Iterant.fromIndexedSeq { + outbox.map { case (sender, message) => + MessageReceived(sender, message) + } + } + + override def sendMessage( + recipient: TestKey, + message: TestMessage + ): Task[Unit] = + inbox.update(_ :+ MessageReceived(recipient, message)) + } + + object TestNetwork { + def apply(outbox: Vector[TestKeyAndMessage]) = + Ref + .of[Task, Vector[TestMessageReceived]](Vector.empty) + .map(new TestNetwork(outbox, _)) + } + + behavior of "splitter" + + it should "split and merge messages" in { + val messages = Vector( + "Alice" -> TestFoo("spam"), + "Bob" -> TestBar(42), + "Charlie" -> TestFoo("eggs") + ) + val resources = for { + network <- Resource.liftF(TestNetwork(messages)) + (fooNetwork, barNetwork) <- Network + .splitter[Task, TestKey, TestMessage, String, Int](network)( + split = { + case TestFoo(msg) => Left(msg) + case TestBar(msg) => Right(msg) + }, + merge = { + case Left(msg) => TestFoo(msg) + case Right(msg) => TestBar(msg) + } + ) + } yield (network, fooNetwork, barNetwork) + + val 
test = resources.use { case (network, fooNetwork, barNetwork) => + for { + fms <- fooNetwork.incomingMessages.take(2).toListL + bms <- barNetwork.incomingMessages.take(1).toListL + _ <- barNetwork.sendMessage("Dave", 123) + _ <- fooNetwork.sendMessage("Eve", "Adam") + _ <- barNetwork.sendMessage("Fred", 456) + nms <- network.inbox.get + } yield { + fms shouldBe List( + MessageReceived("Alice", "spam"), + MessageReceived("Charlie", "eggs") + ) + bms shouldBe List(MessageReceived("Bob", 42)) + nms shouldBe List( + MessageReceived("Dave", TestBar(123)), + MessageReceived("Eve", TestFoo("Adam")), + MessageReceived("Fred", TestBar(456)) + ) + } + } + + test.runToFuture + } +} diff --git a/metronome/networking/test/src/io/iohk/metronome/networking/RemoteConnectionManagerTestUtils.scala b/metronome/networking/test/src/io/iohk/metronome/networking/RemoteConnectionManagerTestUtils.scala new file mode 100644 index 00000000..258464e8 --- /dev/null +++ b/metronome/networking/test/src/io/iohk/metronome/networking/RemoteConnectionManagerTestUtils.scala @@ -0,0 +1,66 @@ +package io.iohk.metronome.networking + +import cats.effect.Resource +import io.iohk.metronome.crypto.{ECKeyPair, ECPublicKey} + +import java.net.{InetSocketAddress, ServerSocket} +import java.security.SecureRandom +import monix.eval.Task +import monix.execution.Scheduler +import org.scalatest.Assertion + +import scala.concurrent.Future +import scala.util.Random +import scodec.bits.ByteVector +import scodec.Codec + +object RemoteConnectionManagerTestUtils { + def customTestCaseResourceT[T]( + fixture: Resource[Task, T] + )(theTest: T => Task[Assertion])(implicit s: Scheduler): Future[Assertion] = { + fixture.use(fix => theTest(fix)).runToFuture + } + + def customTestCaseT[T]( + test: => Task[Assertion] + )(implicit s: Scheduler): Future[Assertion] = { + test.runToFuture + } + + def randomAddress(): InetSocketAddress = { + val s = new ServerSocket(0) + try { + new InetSocketAddress("localhost", s.getLocalPort) + } finally { + s.close() + } + } + + import scodec.codecs._ + + sealed abstract class TestMessage + case class MessageA(i: Int) extends TestMessage + case class MessageB(s: String) extends TestMessage + + object TestMessage { + implicit val messageCodec: Codec[TestMessage] = discriminated[TestMessage] + .by(uint8) + .typecase(1, int32.as[MessageA]) + .typecase(2, utf8.as[MessageB]) + } + + def getFakeRandomKey(): ECPublicKey = { + val array = new Array[Byte](ECPublicKey.Length) + Random.nextBytes(array) + ECPublicKey(ByteVector(array)) + } + + case class NodeInfo(keyPair: ECKeyPair) + + object NodeInfo { + def generateRandom(secureRandom: SecureRandom): NodeInfo = { + val keyPair = ECKeyPair.generate(secureRandom) + NodeInfo(keyPair) + } + } +} diff --git a/metronome/networking/test/src/io/iohk/metronome/networking/RemoteConnectionManagerWithMockProviderSpec.scala b/metronome/networking/test/src/io/iohk/metronome/networking/RemoteConnectionManagerWithMockProviderSpec.scala new file mode 100644 index 00000000..00fb2d16 --- /dev/null +++ b/metronome/networking/test/src/io/iohk/metronome/networking/RemoteConnectionManagerWithMockProviderSpec.scala @@ -0,0 +1,366 @@ +package io.iohk.metronome.networking + +import cats.effect.Resource +import io.iohk.metronome.crypto.ECPublicKey +import io.iohk.metronome.networking.ConnectionHandler.ConnectionAlreadyClosedException +import io.iohk.metronome.networking.EncryptedConnectionProvider.DecodingError +import io.iohk.metronome.networking.MockEncryptedConnectionProvider._ +import 
io.iohk.metronome.networking.RemoteConnectionManager.RetryConfig.RandomJitterConfig +import io.iohk.metronome.networking.RemoteConnectionManager.{ + ClusterConfig, + RetryConfig +} +import io.iohk.metronome.networking.RemoteConnectionManagerTestUtils._ +import io.iohk.metronome.networking.RemoteConnectionManagerWithMockProviderSpec.{ + RemoteConnectionManagerOps, + buildConnectionsManagerWithMockProvider, + buildTestCaseWithNPeers, + defaultToMake, + fakeLocalAddress, + longRetryConfig +} +import io.iohk.metronome.tracer.Tracer +import monix.eval.Task +import monix.execution.Scheduler +import org.scalatest.flatspec.AsyncFlatSpecLike +import org.scalatest.matchers.should.Matchers + +import java.net.InetSocketAddress +import scala.concurrent.duration._ + +class RemoteConnectionManagerWithMockProviderSpec + extends AsyncFlatSpecLike + with Matchers { + implicit val testScheduler = + Scheduler.fixedPool("RemoteConnectionManagerUtSpec", 16) + implicit val timeOut = 5.seconds + + behavior of "RemoteConnectionManagerWithMockProvider" + + it should "continue to make connections to unresponsive peer with exponential backoff" in customTestCaseT { + MockEncryptedConnectionProvider().flatMap(provider => + buildConnectionsManagerWithMockProvider(provider) + .use { connectionManager => + for { + _ <- Task.sleep(1.second) + stats <- provider.getStatistics + acquiredConnections <- connectionManager.getAcquiredConnections + } yield { + assert(stats.maxInFlightConnections == 1) + assert( + stats.connectionCounts + .get(defaultToMake) + .exists(count => count == 2 || count == 3) + ) + assert(acquiredConnections.isEmpty) + } + } + ) + } + + it should "continue to make connections to unresponsive peers one connection at the time" in customTestCaseT { + val connectionToMake = + (0 to 3).map(_ => (getFakeRandomKey(), fakeLocalAddress)).toSet + MockEncryptedConnectionProvider().flatMap(provider => + buildConnectionsManagerWithMockProvider( + provider, + nodesInCluster = connectionToMake + ) + .use { connectionManager => + for { + _ <- Task.sleep(1.second) + stats <- provider.getStatistics + acquiredConnections <- connectionManager.getAcquiredConnections + } yield { + assert( + connectionToMake.forall(connection => + stats.connectionCounts + .get(connection._1) + .exists(count => count == 2 || count == 3) + ) + ) + assert(stats.maxInFlightConnections == 1) + assert(acquiredConnections.isEmpty) + } + } + ) + } + + it should "connect to online peers" in customTestCaseResourceT( + buildTestCaseWithNPeers(4) + ) { case (provider, manager, _) => + for { + stats <- provider.getStatistics + acquiredConnections <- manager.getAcquiredConnections + } yield { + assert(stats.maxInFlightConnections == 1) + assert(acquiredConnections.size == 4) + } + } + + it should "send messages to online peers" in customTestCaseResourceT( + buildTestCaseWithNPeers(4) + ) { case (provider, manager, _) => + for { + acquiredConnections <- manager.getAcquiredConnections + _ <- manager.getAcquiredConnections.flatMap(keys => + Task.traverse(keys)(key => manager.sendMessage(key, MessageA(2))) + ) + received <- provider.getReceivedMessagesPerPeer.map(_.map(_._2)) + stats <- provider.getStatistics + } yield { + assert(stats.maxInFlightConnections == 1) + assert(acquiredConnections.size == 4) + assert( + received.forall(peerMessages => peerMessages.contains(MessageA(2))) + ) + } + } + + it should "try to reconnect disconnected peer" in customTestCaseResourceT( + buildTestCaseWithNPeers(2) + ) { case (provider, manager, _) => + for { + disconnectedPeer 
<- provider.randomPeerDisconnect() + _ <- manager.waitForNConnections(1) + notContainDisconnectedPeer <- manager.notContainsConnection( + disconnectedPeer + ) + _ <- provider.registerOnlinePeer(disconnectedPeer.key) + _ <- manager.waitForNConnections(2) + containsAfterReconnect <- manager.containsConnection(disconnectedPeer) + } yield { + assert(notContainDisconnectedPeer) + assert(containsAfterReconnect) + } + } + + it should "try to reconnect to failed peer after failed send" in customTestCaseResourceT( + buildTestCaseWithNPeers(2) + ) { case (provider, manager, _) => + for { + disconnectedPeer <- provider.failRandomPeer() + _ <- Task.sleep(100.milliseconds) + // remote peer failed without any notice, we still have it in our acquired connections + containsFailedPeer <- manager.containsConnection(disconnectedPeer) + sendResult <- manager + .sendMessage(disconnectedPeer.key, MessageA(1)) + .map(result => result.left.getOrElse(null)) + _ <- Task( + assert( + sendResult == ConnectionAlreadyClosedException(disconnectedPeer.key) + ) + ) + notContainsFailedPeerAfterSend <- manager.notContainsConnection( + disconnectedPeer + ) + _ <- provider.registerOnlinePeer(disconnectedPeer.key) + _ <- manager.waitForNConnections(2) + containsFailedAfterReconnect <- manager.containsConnection( + disconnectedPeer + ) + } yield { + assert(containsFailedPeer) + assert(notContainsFailedPeerAfterSend) + assert(containsFailedAfterReconnect) + } + } + + it should "fail sending message to unknown peer" in customTestCaseResourceT( + buildTestCaseWithNPeers(2) + ) { case (_, manager, _) => + val randomKey = getFakeRandomKey() + for { + sendResult <- manager.sendMessage(randomKey, MessageA(1)) + } yield { + assert(sendResult.isLeft) + assert( + sendResult.left.getOrElse(null) == ConnectionAlreadyClosedException( + randomKey + ) + ) + } + } + + it should "deny not allowed incoming connections " in customTestCaseResourceT( + buildTestCaseWithNPeers(2) + ) { case (provider, manager, _) => + for { + incomingPeerConnection <- provider.newIncomingPeer( + getFakeRandomKey() + ) + _ <- Task.sleep(100.milliseconds) + notContainsNotAllowedIncoming <- manager.notContainsConnection( + incomingPeerConnection + ) + closedIncoming <- incomingPeerConnection.isClosed + } yield { + assert(notContainsNotAllowedIncoming) + assert(closedIncoming) + } + } + + it should "allow configured incoming connections" in customTestCaseResourceT( + buildTestCaseWithNPeers(2, shouldBeOnline = false, longRetryConfig) + ) { case (provider, manager, clusterPeers) => + for { + initialAcquired <- manager.getAcquiredConnections + incomingConnection <- provider.newIncomingPeer(clusterPeers.head) + _ <- manager.waitForNConnections(1) + containsIncoming <- manager.containsConnection(incomingConnection) + } yield { + assert(initialAcquired.isEmpty) + assert(containsIncoming) + } + } + + it should "prefer most fresh incoming connection" in customTestCaseResourceT( + buildTestCaseWithNPeers(2, shouldBeOnline = false, longRetryConfig) + ) { case (provider, manager, clusterPeers) => + for { + initialAcquired <- manager.getAcquiredConnections + firstIncoming <- provider.newIncomingPeer(clusterPeers.head) + _ <- manager.waitForNConnections(1) + containsIncoming <- manager.containsConnection(firstIncoming) + duplicatedIncoming <- provider.newIncomingPeer(clusterPeers.head) + _ <- Task.sleep(500.millis) // Let the offered connection be processed. 
+ duplicatedIncomingClosed <- duplicatedIncoming.isClosed + firstIncomingClosed <- firstIncoming.isClosed + } yield { + assert(initialAcquired.isEmpty) + assert(containsIncoming) + assert(!duplicatedIncomingClosed) + assert(firstIncomingClosed) + } + } + + it should "disconnect from peer on which connection error happened" in customTestCaseResourceT( + buildTestCaseWithNPeers(2) + ) { case (provider, manager, _) => + for { + initialAcquired <- manager.getAcquiredConnections + randomAcquiredConnection <- provider.getAllRegisteredPeers.map(_.head) + _ <- randomAcquiredConnection.pushRemoteEvent(Some(Left(DecodingError))) + _ <- manager.waitForNConnections(1) + _ <- Task.sleep(500.millis) // Let the offered connection be processed. + errorIsClosed <- randomAcquiredConnection.isClosed + } yield { + assert(initialAcquired.size == 2) + assert(errorIsClosed) + } + } + + it should "receive messages from all connections" in customTestCaseResourceT( + buildTestCaseWithNPeers(2) + ) { case (provider, manager, _) => + for { + acquiredConnections <- manager.getAcquiredConnections + connections <- provider.getAllRegisteredPeers + _ <- Task.traverse(connections)(conn => + conn.pushRemoteEvent(Some(Right(MessageA(1)))) + ) + received <- manager.incomingMessages.take(2).toListL + } yield { + assert(acquiredConnections.size == 2) + assert(received.size == 2) + } + } + +} + +object RemoteConnectionManagerWithMockProviderSpec { + implicit class RemoteConnectionManagerOps( + manager: RemoteConnectionManager[Task, ECPublicKey, TestMessage] + ) { + def waitForNConnections( + n: Int + )(implicit timeOut: FiniteDuration): Task[Unit] = { + manager.getAcquiredConnections + .restartUntil(connections => connections.size == n) + .timeout(timeOut) + .void + } + + def containsConnection( + connection: MockEncryptedConnection + ): Task[Boolean] = { + manager.getAcquiredConnections.map(connections => + connections.contains(connection.remotePeerInfo._1) + ) + } + + def notContainsConnection( + connection: MockEncryptedConnection + ): Task[Boolean] = { + containsConnection(connection).map(contains => !contains) + } + } + + val noJitterConfig = RandomJitterConfig.buildJitterConfig(0).get + val quickRetryConfig = + RetryConfig(50.milliseconds, 2, 2.seconds, noJitterConfig, Duration.Zero) + val longRetryConfig: RetryConfig = + RetryConfig(5.seconds, 2, 20.seconds, noJitterConfig, Duration.Zero) + + def buildTestCaseWithNPeers( + n: Int, + shouldBeOnline: Boolean = true, + retryConfig: RetryConfig = quickRetryConfig + )(implicit timeOut: FiniteDuration): Resource[ + Task, + ( + MockEncryptedConnectionProvider, + RemoteConnectionManager[Task, ECPublicKey, TestMessage], + Set[ECPublicKey] + ) + ] = { + val keys = (0 until n).map(_ => getFakeRandomKey()).toSet + + for { + provider <- Resource.liftF(MockEncryptedConnectionProvider()) + _ <- Resource.liftF { + if (shouldBeOnline) { + Task.traverse(keys)(key => provider.registerOnlinePeer(key)) + } else { + Task.unit + } + } + manager <- buildConnectionsManagerWithMockProvider( + provider, + retryConfig = retryConfig, + nodesInCluster = keys.map(key => (key, fakeLocalAddress)) + ) + _ <- Resource.liftF { + if (shouldBeOnline) { + manager.waitForNConnections(n) + } else { + Task.unit + } + } + } yield (provider, manager, keys) + } + + val fakeLocalAddress = new InetSocketAddress("localhost", 127) + + val defalutAllowed = getFakeRandomKey() + val defaultToMake = getFakeRandomKey() + + implicit val tracers: NetworkTracers[Task, ECPublicKey, TestMessage] = + 
NetworkTracers(Tracer.noOpTracer) + + def buildConnectionsManagerWithMockProvider( + ec: MockEncryptedConnectionProvider, + retryConfig: RetryConfig = quickRetryConfig, + nodesInCluster: Set[(ECPublicKey, InetSocketAddress)] = Set( + (defaultToMake, fakeLocalAddress) + ) + ): Resource[ + Task, + RemoteConnectionManager[Task, ECPublicKey, TestMessage] + ] = { + val clusterConfig = ClusterConfig(nodesInCluster) + + RemoteConnectionManager(ec, clusterConfig, retryConfig) + } + +} diff --git a/metronome/networking/test/src/io/iohk/metronome/networking/RemoteConnectionManagerWithScalanetProviderSpec.scala b/metronome/networking/test/src/io/iohk/metronome/networking/RemoteConnectionManagerWithScalanetProviderSpec.scala new file mode 100644 index 00000000..eb7bb72f --- /dev/null +++ b/metronome/networking/test/src/io/iohk/metronome/networking/RemoteConnectionManagerWithScalanetProviderSpec.scala @@ -0,0 +1,383 @@ +package io.iohk.metronome.networking + +import cats.data.NonEmptyList +import cats.effect.concurrent.Ref +import cats.effect.{Concurrent, ContextShift, Resource, Sync, Timer} +import io.circe.{Encoder, Json, JsonObject} +import io.iohk.metronome.crypto.{ECKeyPair, ECPublicKey} +import io.iohk.metronome.networking.ConnectionHandler.MessageReceived +import io.iohk.metronome.networking.RemoteConnectionManager.{ + ClusterConfig, + RetryConfig +} +import io.iohk.metronome.networking.RemoteConnectionManagerTestUtils._ +import io.iohk.metronome.networking.RemoteConnectionManagerWithScalanetProviderSpec.{ + Cluster, + buildTestConnectionManager +} +import io.iohk.metronome.logging.{HybridLog, HybridLogObject, LogTracer} +import io.iohk.scalanet.peergroup.dynamictls.DynamicTLSPeerGroup.FramingConfig +import io.iohk.scalanet.peergroup.PeerGroup + +import java.net.InetSocketAddress +import java.security.SecureRandom +import monix.eval.{Task, TaskLift, TaskLike} +import monix.execution.Scheduler +import monix.execution.UncaughtExceptionReporter +import org.scalatest.flatspec.AsyncFlatSpecLike +import org.scalatest.Inspectors +import org.scalatest.matchers.should.Matchers + +import scala.concurrent.duration._ +import scodec.Codec + +class RemoteConnectionManagerWithScalanetProviderSpec + extends AsyncFlatSpecLike + with Matchers { + import RemoteConnectionManagerWithScalanetProviderSpec.ecPublicKeyEncoder + + implicit val testScheduler = + Scheduler.fixedPool( + "RemoteConnectionManagerSpec", + 16, + reporter = UncaughtExceptionReporter { + case ex: IllegalStateException + if ex.getMessage.contains("executor not accepting a task") => + case _: PeerGroup.ChannelBrokenException[_] => + // Probably test already closed with some task running in the background. 
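+ // These cases are swallowed on purpose so that shutdown noise from background tasks does not fail otherwise green tests; anything else is still reported.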
+ case ex => + UncaughtExceptionReporter.default.reportFailure(ex) + } + ) + + implicit val timeOut = 10.seconds + + behavior of "RemoteConnectionManagerWithScalanetProvider" + + it should "start connectionManager without any connections" in customTestCaseResourceT( + buildTestConnectionManager[Task, ECPublicKey, TestMessage]() + ) { connectionManager => + for { + connections <- connectionManager.getAcquiredConnections + } yield assert(connections.isEmpty) + } + + it should "build fully connected cluster of 3 nodes" in customTestCaseResourceT( + Cluster.buildCluster(3) + ) { cluster => + for { + size <- cluster.clusterSize + eachNodeCount <- cluster.getEachNodeConnectionsCount + } yield { + Inspectors.forAll(eachNodeCount)(count => count shouldEqual 2) + size shouldEqual 3 + } + } + + it should "build fully connected cluster of 4 nodes" in customTestCaseResourceT( + Cluster.buildCluster(4) + ) { cluster => + for { + size <- cluster.clusterSize + eachNodeCount <- cluster.getEachNodeConnectionsCount + } yield { + Inspectors.forAll(eachNodeCount)(count => count shouldEqual 3) + size shouldEqual 4 + } + } + + it should "send and receive messages with other nodes in cluster" in customTestCaseResourceT( + Cluster.buildCluster(3) + ) { cluster => + for { + eachNodeCount <- cluster.getEachNodeConnectionsCount + sendResult <- cluster.sendMessageFromRandomNodeToAllOthers(MessageA(1)) + (sender, receivers) = sendResult + received <- Task.traverse(receivers.toList)(receiver => + cluster.getMessageFromNode(receiver) + ) + } yield { + Inspectors.forAll(eachNodeCount)(count => count shouldEqual 2) + receivers.size shouldEqual 2 + received.size shouldEqual 2 + //every node should have received the same message + Inspectors.forAll(received) { receivedMessage => + receivedMessage shouldBe MessageReceived(sender, MessageA(1)) + } + } + } + + it should "eventually reconnect to offline node" in customTestCaseResourceT( + Cluster.buildCluster(3) + ) { cluster => + for { + size <- cluster.clusterSize + killed <- cluster.shutdownRandomNode + _ <- cluster.sendMessageFromRandomNodeToAllOthers(MessageA(1)) + (address, keyPair, clusterConfig) = killed + _ <- cluster.waitUntilEveryNodeHaveNConnections(1) + // be offline for a moment + _ <- Task.sleep(3.seconds) + connectionAfterFailure <- cluster.getEachNodeConnectionsCount + _ <- cluster.startNode(address, keyPair, clusterConfig) + _ <- cluster.waitUntilEveryNodeHaveNConnections(2) + } yield { + size shouldEqual 3 + Inspectors.forAll(connectionAfterFailure) { connections => + connections shouldEqual 1 + } + } + } +} +object RemoteConnectionManagerWithScalanetProviderSpec { + val secureRandom = new SecureRandom() + val standardFraming = + FramingConfig.buildStandardFrameConfig(1000000, 4).getOrElse(null) + val testIncomingQueueSize = 20 + + implicit val ecPublicKeyEncoder: Encoder[ECPublicKey] = + Encoder.instance(key => Json.fromString(key.bytes.toHex)) + + // Just an example of setting up logging. 
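+ // Each NetworkEvent becomes a structured JSON log entry: the peer key is encoded with circe, a HybridLog instance picks the level, message and payload, and LogTracer.hybrid wraps it into the Tracer consumed by NetworkTracers.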
+ implicit def tracers[F[_]: Sync, K: io.circe.Encoder, M] + : NetworkTracers[F, K, M] = { + import io.circe.syntax._ + import NetworkEvent._ + + implicit val peerEncoder: Encoder.AsObject[Peer[K]] = + Encoder.AsObject.instance { case Peer(key, address) => + JsonObject("key" -> key.asJson, "address" -> address.toString.asJson) + } + + implicit val hybridLog: HybridLog[NetworkEvent[K, M]] = + HybridLog.instance[NetworkEvent[K, M]]( + level = _ => HybridLogObject.Level.Debug, + message = _.getClass.getSimpleName, + event = { + case e: ConnectionUnknown[_] => e.peer.asJsonObject + case e: ConnectionRegistered[_] => e.peer.asJsonObject + case e: ConnectionDeregistered[_] => e.peer.asJsonObject + case e: ConnectionDiscarded[_] => e.peer.asJsonObject + case e: ConnectionSendError[_] => e.peer.asJsonObject + case e: ConnectionFailed[_] => + e.peer.asJsonObject.add("error", e.error.toString.asJson) + case e: ConnectionReceiveError[_] => + e.peer.asJsonObject.add("error", e.error.toString.asJson) + case e: NetworkEvent.MessageReceived[_, _] => e.peer.asJsonObject + case e: NetworkEvent.MessageSent[_, _] => e.peer.asJsonObject + } + ) + + NetworkTracers(LogTracer.hybrid[F, NetworkEvent[K, M]]) + } + + def buildTestConnectionManager[ + F[_]: Concurrent: TaskLift: TaskLike: Timer, + K: Codec: Encoder, + M: Codec + ]( + bindAddress: InetSocketAddress = randomAddress(), + nodeKeyPair: ECKeyPair = ECKeyPair.generate(secureRandom), + secureRandom: SecureRandom = secureRandom, + useNativeTlsImplementation: Boolean = false, + framingConfig: FramingConfig = standardFraming, + maxIncomingQueueSizePerPeer: Int = testIncomingQueueSize, + clusterConfig: ClusterConfig[K] = ClusterConfig( + Set.empty[(K, InetSocketAddress)] + ), + retryConfig: RetryConfig = RetryConfig.default + )(implicit + s: Scheduler, + cs: ContextShift[F] + ): Resource[F, RemoteConnectionManager[F, K, M]] = { + ScalanetConnectionProvider + .scalanetProvider[F, K, M]( + bindAddress, + nodeKeyPair, + secureRandom, + useNativeTlsImplementation, + framingConfig, + maxIncomingQueueSizePerPeer + ) + .flatMap(prov => + RemoteConnectionManager(prov, clusterConfig, retryConfig) + ) + } + + type ClusterNodes = Map[ + ECPublicKey, + ( + RemoteConnectionManager[Task, ECPublicKey, TestMessage], + ECKeyPair, + ClusterConfig[ECPublicKey], + Task[Unit] + ) + ] + + def buildClusterNodes( + keys: NonEmptyList[NodeInfo] + )(implicit + s: Scheduler, + timeOut: FiniteDuration + ): Task[Ref[Task, ClusterNodes]] = { + val keyWithAddress = keys.toList.map(key => (key, randomAddress())).toSet + + for { + nodes <- Ref.of[Task, ClusterNodes](Map.empty) + _ <- Task.traverse(keyWithAddress) { case (info, address) => + val clusterConfig = ClusterConfig(clusterNodes = + keyWithAddress.map(keyWithAddress => + (keyWithAddress._1.keyPair.pub, keyWithAddress._2) + ) + ) + + buildTestConnectionManager[Task, ECPublicKey, TestMessage]( + bindAddress = address, + nodeKeyPair = info.keyPair, + clusterConfig = clusterConfig + ).allocated.flatMap { case (manager, release) => + nodes.update(map => + map + (manager.getLocalPeerInfo._1 -> (manager, info.keyPair, clusterConfig, release)) + ) + } + } + + } yield nodes + } + + class Cluster(nodes: Ref[Task, ClusterNodes]) { + + private def broadcastToAllConnections( + manager: RemoteConnectionManager[Task, ECPublicKey, TestMessage], + message: TestMessage + ) = { + manager.getAcquiredConnections.flatMap { connections => + Task + .parTraverseUnordered(connections)(connectionKey => + manager.sendMessage(connectionKey, message) + ) + .map { _ 
=> + connections + } + } + + } + + def clusterSize: Task[Int] = nodes.get.map(_.size) + + def getEachNodeConnectionsCount: Task[List[Int]] = { + for { + runningNodes <- nodes.get.flatMap(nodes => + Task.traverse(nodes.values.map(_._1))(manager => + manager.getAcquiredConnections + ) + ) + + } yield runningNodes.map(_.size).toList + } + + def waitUntilEveryNodeHaveNConnections( + n: Int + )(implicit timeOut: FiniteDuration): Task[List[Int]] = { + getEachNodeConnectionsCount + .restartUntil(counts => + counts.forall(currentNodeConnectionCount => + currentNodeConnectionCount == n + ) + ) + .timeout(timeOut) + } + + def closeAllNodes: Task[Unit] = { + nodes.get.flatMap { nodes => + Task + .parTraverseUnordered(nodes.values) { case (_, _, _, release) => + release + } + .void + } + } + + def sendMessageFromRandomNodeToAllOthers( + message: TestMessage + ): Task[(ECPublicKey, Set[ECPublicKey])] = { + for { + runningNodes <- nodes.get + (key, (node, _, _, _)) = runningNodes.head + nodesReceivingMessage <- broadcastToAllConnections(node, message) + } yield (key, nodesReceivingMessage) + } + + def sendMessageFromAllClusterNodesToTheirConnections( + message: TestMessage + ): Task[List[(ECPublicKey, Set[ECPublicKey])]] = { + nodes.get.flatMap { current => + Task.parTraverseUnordered(current.values) { case (manager, _, _, _) => + broadcastToAllConnections(manager, message).map { receivers => + (manager.getLocalPeerInfo._1 -> receivers) + } + } + } + } + + def getMessageFromNode(key: ECPublicKey) = { + nodes.get.flatMap { runningNodes => + runningNodes(key)._1.incomingMessages.take(1).toListL.map(_.head) + } + } + + def shutdownRandomNode: Task[ + (InetSocketAddress, ECKeyPair, ClusterConfig[ECPublicKey]) + ] = { + for { + current <- nodes.get + ( + randomNodeKey, + (randomManager, nodeKeyPair, clusterConfig, randomRelease) + ) = current.head + _ <- randomRelease + _ <- nodes.update(current => current - randomNodeKey) + } yield (randomManager.getLocalPeerInfo._2, nodeKeyPair, clusterConfig) + } + + def startNode( + bindAddress: InetSocketAddress, + keyPair: ECKeyPair, + clusterConfig: ClusterConfig[ECPublicKey] + )(implicit s: Scheduler): Task[Unit] = { + buildTestConnectionManager[Task, ECPublicKey, TestMessage]( + bindAddress = bindAddress, + nodeKeyPair = keyPair, + clusterConfig = clusterConfig + ).allocated.flatMap { case (manager, release) => + nodes.update { current => + current + (manager.getLocalPeerInfo._1 -> (manager, keyPair, clusterConfig, release)) + } + } + } + + } + + object Cluster { + def buildCluster(size: Int)(implicit + s: Scheduler, + timeOut: FiniteDuration + ): Resource[Task, Cluster] = { + val nodeInfos = NonEmptyList.fromListUnsafe( + ((0 until size).map(_ => NodeInfo.generateRandom(secureRandom)).toList) + ) + + Resource.make { + for { + nodes <- buildClusterNodes(nodeInfos) + cluster = new Cluster(nodes) + _ <- cluster.getEachNodeConnectionsCount + .restartUntil(counts => counts.forall(count => count == size - 1)) + .timeout(timeOut) + } yield cluster + } { cluster => cluster.closeAllNodes } + } + + } + +} diff --git a/metronome/rocksdb/src/io/iohk/metronome/rocksdb/RocksDBStore.scala b/metronome/rocksdb/src/io/iohk/metronome/rocksdb/RocksDBStore.scala new file mode 100644 index 00000000..8134ea25 --- /dev/null +++ b/metronome/rocksdb/src/io/iohk/metronome/rocksdb/RocksDBStore.scala @@ -0,0 +1,519 @@ +package io.iohk.metronome.rocksdb + +import cats._ +import cats.implicits._ +import cats.data.ReaderT +import cats.effect.{Resource, Sync, ContextShift, Blocker} +import 
io.iohk.metronome.storage.{ + KVStore, + KVStoreOp, + KVStoreRead, + KVStoreReadOp +} +import io.iohk.metronome.storage.KVStoreOp.{Put, Get, Delete} +import java.util.concurrent.locks.ReentrantReadWriteLock +import org.rocksdb.{ + RocksDB, + WriteBatch, + WriteOptions, + ReadOptions, + Options, + DBOptions, + ColumnFamilyOptions, + ColumnFamilyDescriptor, + ColumnFamilyHandle, + BlockBasedTableConfig, + BloomFilter, + CompressionType, + ClockCache +} +import scodec.{Encoder, Decoder} +import scodec.bits.BitVector +import scala.collection.mutable +import java.nio.file.Path +import scala.annotation.nowarn + +/** Implementation of intepreters for `KVStore[N, A]` and `KVStoreRead[N, A]`operations + * with various semantics. Application code is not expected to interact with this class + * directly. Instead, some middle layer should be passed as a dependency to code that + * delegates to the right interpreter in this class. + * + * For example if our data schema is append-only, there's no need to pay the performance + * penalty for using locking, or if two parts of the application are isolated from each other, + * locking could be performed in their respective middle-layers, before they forward the + * query for execution to this class. + */ +class RocksDBStore[F[_]: Sync: ContextShift]( + db: RocksDBStore.DBSupport, + lock: RocksDBStore.LockSupport, + blocker: Blocker, + handles: Map[RocksDBStore.Namespace, ColumnFamilyHandle] +) { + + import RocksDBStore.{Namespace, DBQuery, autoCloseableR} + + private val kvs = KVStore.instance[Namespace] + + // Batch execution needs these variables for accumulating operations + // and executing them against the database. They are going to be + // passed along in a Reader monad to the Free compiler. + type BatchEnv = WriteBatch + + // Type aliases to support the `~>` transformation with types that + // only have 1 generic type argument `A`. + type Batch[A] = + ({ type L[A] = ReaderT[Eval, BatchEnv, A] })#L[A] + + type KVNamespacedOp[A] = + ({ type L[A] = KVStoreOp[Namespace, A] })#L[A] + + type KVNamespacedReadOp[A] = + ({ type L[A] = KVStoreReadOp[Namespace, A] })#L[A] + + /** Execute the accumulated write operations in a batch. */ + private val writeBatch: ReaderT[Eval, BatchEnv, Unit] = + ReaderT { batch => + if (batch.hasPut() || batch.hasDelete()) + db.write(batch) >> + Eval.always(batch.clear()) + else + ().pure[Eval] + } + + /** Execute one `Get` operation. */ + private def get[K, V](op: Get[Namespace, K, V]): Eval[Option[V]] = { + for { + kbs <- encode(op.key)(op.keyEncoder) + mvbs <- db.get(handles(op.namespace), kbs) + mv <- mvbs match { + case None => + none.pure[Eval] + + case Some(bytes) => + decode(bytes)(op.valueDecoder).map(_.some) + } + } yield mv + } + + /** Collect writes in a batch, until we either get to the end, or there's a read. + * This way writes are atomic, and reads can see their effect. + * + * In the next version of RocksDB we can use transactions to make reads and writes + * run in isolation. + */ + private val batchingCompiler: KVNamespacedOp ~> Batch = + new (KVNamespacedOp ~> Batch) { + def apply[A](fa: KVNamespacedOp[A]): Batch[A] = + fa match { + case op @ Put(n, k, v) => + ReaderT { batch => + for { + kbs <- encode(k)(op.keyEncoder) + vbs <- encode(v)(op.valueEncoder) + _ = batch.put(handles(n), kbs, vbs) + } yield () + } + + case op @ Get(_, _) => + // Execute any pending deletes and puts before performing the read. 
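+ // This gives reads within a batch a read-your-writes view, at the cost of splitting the batch into more than one atomic write.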
+ writeBatch >> ReaderT.liftF(get(op)) + + case op @ Delete(n, k) => + ReaderT { batch => + for { + kbs <- encode(k)(op.keyEncoder) + _ = batch.delete(handles(n), kbs) + } yield () + } + } + } + + /** Intended for reads, with fallback to writes executed individually. */ + private val nonBatchingCompiler: KVNamespacedOp ~> Eval = + new (KVNamespacedOp ~> Eval) { + def apply[A](fa: KVNamespacedOp[A]): Eval[A] = + fa match { + case op @ Get(_, _) => + get(op) + + case op @ Put(n, k, v) => + for { + kbs <- encode(k)(op.keyEncoder) + vbs <- encode(v)(op.valueEncoder) + _ <- lock.withLockUpgrade(db.put(handles(n), kbs, vbs)) + } yield () + + case op @ Delete(n, k) => + for { + kbs <- encode(k)(op.keyEncoder) + _ <- lock.withLockUpgrade(db.delete(handles(n), kbs)) + } yield () + } + } + + private def encode[T](value: T)(implicit ev: Encoder[T]): Eval[Array[Byte]] = + Eval.always(ev.encode(value).map(_.toByteArray).require) + + private def decode[T](bytes: Array[Byte])(implicit ev: Decoder[T]): Eval[T] = + Eval.always(ev.decodeValue(BitVector(bytes)).require) + + private def block[A](evalA: Eval[A]): F[A] = + ContextShift[F].blockOn(blocker) { + Sync[F].delay(evalA.value) + } + + /** Mostly meant for writing batches atomically. + * + * If a read is found the accumulated writes are performed, + * then the read happens, before batching carries on; + * this breaks the atomicity of writes. + * + * This version doesn't use any locking, so it's suitable for + * append-only data stores, or writing to independent stores + * in parallel. + */ + def runWithBatchingNoLock[A]( + program: KVStore[Namespace, A] + ): DBQuery[F, A] = + autoCloseableR(new WriteBatch()).use { batch => + block { + (program.foldMap(batchingCompiler) <* writeBatch).run(batch) + } + } + + /** Same as `runWithBatchingNoLock`, but write lock is taken out + * to make sure concurrent reads are not affected. + * + * This version is suitable for cases where data may be deleted, + * which could result for example in foreign key references + * becoming invalid after they are read, before the data they + * point to is retrieved. + */ + def runWithBatching[A](program: KVStore[Namespace, A]): DBQuery[F, A] = + autoCloseableR(new WriteBatch()).use { batch => + block { + lock.withWriteLock { + (program.foldMap(batchingCompiler) <* writeBatch).run(batch) + } + } + } + + /** Similar to `runWithBatching` in that it can contain both reads + * and writes, but the expectation is that it will mostly be reads. + * + * A read lock is taken out to make sure writes don't affect reads; + * if a write is found, it is executed as an individual operation, + * while a write lock is taken out to protect other reads. Note that + * this breaks the isolation of reads, because to acquire a write lock, + * the read lock has to be released, which gives a chance for other + * threads to get in before the write statement runs. + */ + def runWithoutBatching[A](program: KVStore[Namespace, A]): DBQuery[F, A] = + block { + lock.withReadLock { + program.foldMap(nonBatchingCompiler) + } + } + + /** For strictly read-only operations. + * + * Doesn't use locking, so most suitable for append-only data schemas + * where reads don't need isolation from writes. + */ + def runReadOnlyNoLock[A](program: KVStoreRead[Namespace, A]): DBQuery[F, A] = + block { + kvs.lift(program).foldMap(nonBatchingCompiler) + } + + /** Same as `runReadOnlyNoLock`, but a read lock is taken out + * to make sure concurrent writes cannot affect the results. 
+ * + * This version is suitable for use cases where destructive + * updates are happening. + */ + def runReadOnly[A](program: KVStoreRead[Namespace, A]): DBQuery[F, A] = + block { + lock.withReadLock { + kvs.lift(program).foldMap(nonBatchingCompiler) + } + } +} + +object RocksDBStore { + type Namespace = IndexedSeq[Byte] + + /** Database operations may fail due to a couple of reasons: + * - database connection issues + * - obsolete format stored, codec unable to read data + * + * But it's not expected, so just using `F[A]` for now, + * rather than `EitherT[F, Throwable, A]`. + */ + type DBQuery[F[_], A] = F[A] + + case class Config( + path: Path, + createIfMissing: Boolean, + paranoidChecks: Boolean, + maxThreads: Int, + maxOpenFiles: Int, + verifyChecksums: Boolean, + levelCompaction: Boolean, + blockSizeBytes: Long, + blockCacheSizeBytes: Long + ) + object Config { + def default(path: Path): Config = + Config( + path = path, + // Create DB data directory if it's missing + createIfMissing = true, + // Should the DB raise an error as soon as it detects an internal corruption + paranoidChecks = true, + maxThreads = 1, + maxOpenFiles = 32, + // Force checksum verification of all data that is read from the file system on behalf of a particular read. + verifyChecksums = true, + // In this mode, size target of levels are changed dynamically based on size of the last level. + // https://rocksdb.org/blog/2015/07/23/dynamic-level.html + levelCompaction = true, + // Approximate size of user data packed per block (16 * 1024) + blockSizeBytes = 16384, + // Amount of cache in bytes that will be used by RocksDB (32 * 1024 * 1024) + blockCacheSizeBytes = 33554432 + ) + } + + def apply[F[_]: Sync: ContextShift]( + config: Config, + namespaces: Seq[Namespace] + ): Resource[F, RocksDBStore[F]] = { + + @nowarn // JavaConverters are deprecated in 2.13 + def open( + opts: DBOptions, + cfds: Seq[ColumnFamilyDescriptor], + cfhs: mutable.Buffer[ColumnFamilyHandle] + ): RocksDB = { + import scala.collection.JavaConverters._ + RocksDB.open(opts, config.path.toString, cfds.asJava, cfhs.asJava) + } + + // There is a specific order for closing RocksDB with column families described in + // https://github.com/facebook/rocksdb/wiki/RocksJava-Basics#opening-a-database-with-column-families + // 1. Free all column families handles + // 2. Free DB and DB options + // 3. Free column families options + // So they are created in the opposite order. 
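+ // cats-effect Resources are released in reverse order of acquisition, so acquiring the options first, then the database, then the handle map produces exactly that close order.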
+ for { + _ <- Resource.liftF[F, Unit](Sync[F].delay { + RocksDB.loadLibrary() + }) + + tableConf <- Resource.pure[F, BlockBasedTableConfig] { + mkTableConfig(config) + } + + cfOpts <- autoCloseableR[F, ColumnFamilyOptions] { + new ColumnFamilyOptions() + .setCompressionType(CompressionType.LZ4_COMPRESSION) + .setBottommostCompressionType(CompressionType.ZSTD_COMPRESSION) + .setLevelCompactionDynamicLevelBytes(config.levelCompaction) + .setTableFormatConfig(tableConf) + } + + allNamespaces = RocksDB.DEFAULT_COLUMN_FAMILY.toIndexedSeq +: namespaces + + cfDescriptors = allNamespaces.map { n => + new ColumnFamilyDescriptor(n.toArray, cfOpts) + } + + dbOpts <- autoCloseableR[F, DBOptions] { + new DBOptions() + .setCreateIfMissing(config.createIfMissing) + .setParanoidChecks(config.paranoidChecks) + .setMaxOpenFiles(config.maxOpenFiles) + .setIncreaseParallelism(config.maxThreads) + .setCreateMissingColumnFamilies(true) + } + + readOpts <- autoCloseableR[F, ReadOptions] { + new ReadOptions().setVerifyChecksums(config.verifyChecksums) + } + writeOptions <- autoCloseableR[F, WriteOptions] { + new WriteOptions() + } + + // The handles will be filled as the database is opened. + columnFamilyHandleBuffer = mutable.Buffer.empty[ColumnFamilyHandle] + + db <- autoCloseableR[F, RocksDB] { + open( + dbOpts, + cfDescriptors, + columnFamilyHandleBuffer + ) + } + + columnFamilyHandles <- Resource.make( + (allNamespaces zip columnFamilyHandleBuffer).toMap.pure[F] + ) { _ => + // Make sure all handles are closed, and this happens before the DB is closed. + Sync[F].delay(columnFamilyHandleBuffer.foreach(_.close())) + } + + // Sanity check; if an exception is raised everything will be closed down. + _ = assert( + columnFamilyHandleBuffer.size == allNamespaces.size, + "Should have created a column family handle for each namespace." + + s" Expected ${allNamespaces.size}; got ${columnFamilyHandleBuffer.size}." + ) + + // Use a cached thread pool for blocking on locks and IO. + blocker <- Blocker[F] + + store = new RocksDBStore[F]( + new DBSupport(db, readOpts, writeOptions), + new LockSupport(new ReentrantReadWriteLock()), + blocker, + columnFamilyHandles + ) + + } yield store + } + + /** Remove the database directory. */ + def destroy[F[_]: Sync]( + config: Config + ): F[Unit] = { + autoCloseableR[F, Options] { + new Options() + .setCreateIfMissing(config.createIfMissing) + .setParanoidChecks(config.paranoidChecks) + .setCompressionType(CompressionType.LZ4_COMPRESSION) + .setBottommostCompressionType(CompressionType.ZSTD_COMPRESSION) + .setLevelCompactionDynamicLevelBytes(config.levelCompaction) + .setMaxOpenFiles(config.maxOpenFiles) + .setIncreaseParallelism(config.maxThreads) + .setTableFormatConfig(mkTableConfig(config)) + }.use { options => + Sync[F].delay { + RocksDB.destroyDB(config.path.toString, options) + } + } + } + + private def mkTableConfig(config: Config): BlockBasedTableConfig = + new BlockBasedTableConfig() + .setBlockSize(config.blockSizeBytes) + .setBlockCache(new ClockCache(config.blockCacheSizeBytes)) + .setCacheIndexAndFilterBlocks(true) + .setPinL0FilterAndIndexBlocksInCache(true) + .setFilterPolicy(new BloomFilter(10, false)) + + private def autoCloseableR[F[_]: Sync, R <: AutoCloseable]( + mk: => R + ): Resource[F, R] = + Resource.fromAutoCloseable[F, R](Sync[F].delay(mk)) + + /** Help run reads and writes isolated from each other. 
+ * + * Uses a `ReentrantReadWriteLock` so has to make sure that + * all operations are carried out on the same thread, that's + * why it's working with `Eval` and not `F`. + */ + private class LockSupport( + rwlock: ReentrantReadWriteLock + ) { + + // Batches can interleave multiple reads (and writes); + // to make sure they see a consistent view, writes are + // isolated from reads via locks, so for example if we + // read an ID, then retrieve the record from a different + // collection, we can be sure it hasn't been deleted in + // between the two operations. + private val lockRead = Eval.always { + rwlock.readLock().lock() + } + private val unlockRead = Eval.always { + rwlock.readLock().unlock() + } + private val lockWrite = Eval.always { + rwlock.writeLock().lock() + } + private val unlockWrite = Eval.always { + rwlock.writeLock().unlock() + } + + def withReadLock[A](evalA: Eval[A]): Eval[A] = + bracket(lockRead, unlockRead)(evalA) + + def withWriteLock[A](evalA: Eval[A]): Eval[A] = + bracket(lockWrite, unlockWrite)(evalA) + + /* + * In case there's a write operation among the reads and we haven't + * taken out a write lock, we can replace the read lock we have + * with a write lock, for the duration of the operation, then downgrade + * it back when we're done. + * + * Note that *technically* this is not an upgrade: to acquire the write + * lock, the read lock has to be released first, therefore other threads + * may get the write lock first. It works in the other direction though: + * the write lock can be turned into a read. + * + * See here for the rules of (non-)upgrading and downgrading: + * https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/locks/ReentrantReadWriteLock.html + */ + def withLockUpgrade[A](fa: Eval[A]): Eval[A] = + bracket( + unlockRead >> lockWrite, + lockRead >> unlockWrite + )(fa) + + private def bracket[A](lock: Eval[Unit], unlock: Eval[Unit])( + evalA: Eval[A] + ): Eval[A] = Eval.always { + try { + (lock >> evalA).value + } finally { + unlock.value + } + } + } + + /** Wrap a RocksDB instance.
*/ + private class DBSupport( + db: RocksDB, + readOptions: ReadOptions, + writeOptions: WriteOptions + ) { + def get( + handle: ColumnFamilyHandle, + key: Array[Byte] + ): Eval[Option[Array[Byte]]] = Eval.always { + Option(db.get(handle, readOptions, key)) + } + + def write( + batch: WriteBatch + ): Eval[Unit] = Eval.always { + db.write(writeOptions, batch) + } + + def put( + handle: ColumnFamilyHandle, + key: Array[Byte], + value: Array[Byte] + ): Eval[Unit] = Eval.always { + db.put(handle, writeOptions, key, value) + } + + def delete( + handle: ColumnFamilyHandle, + key: Array[Byte] + ): Eval[Unit] = Eval.always { + db.delete(handle, writeOptions, key) + } + } +} diff --git a/metronome/rocksdb/test/src/io/iohk/metronome/rocksdb/RocksDBStoreProps.scala b/metronome/rocksdb/test/src/io/iohk/metronome/rocksdb/RocksDBStoreProps.scala new file mode 100644 index 00000000..27eb7cd2 --- /dev/null +++ b/metronome/rocksdb/test/src/io/iohk/metronome/rocksdb/RocksDBStoreProps.scala @@ -0,0 +1,543 @@ +package io.iohk.metronome.rocksdb + +import cats.implicits._ +import cats.effect.Resource +import io.iohk.metronome.storage.{ + KVStoreState, + KVStore, + KVCollection, + KVStoreRead +} +import java.nio.file.Files +import monix.eval.Task +import org.scalacheck.commands.Commands +import org.scalacheck.{Properties, Gen, Prop, Test, Arbitrary} +import org.scalacheck.Arbitrary.arbitrary +import org.scalacheck.Prop.forAll +import scala.util.{Try, Success} +import scala.concurrent.duration._ +import scala.annotation.nowarn +import scodec.bits.ByteVector +import scodec.codecs.implicits._ + +// https://github.com/typelevel/scalacheck/blob/master/doc/UserGuide.md#stateful-testing +// https://github.com/typelevel/scalacheck/blob/master/examples/commands-redis/src/test/scala/CommandsRedis.scala + +object RocksDBStoreProps extends Properties("RocksDBStore") { + + override def overrideParameters(p: Test.Parameters): Test.Parameters = + p.withMinSuccessfulTests(20).withMaxSize(100) + + // Equivalent to the in-memory model. + property("equivalent") = RocksDBStoreCommands.property() + + // Run reads and writes concurrently. + property("linearizable") = forAll { + import RocksDBStoreCommands._ + for { + empty <- genInitialState + // Generate some initial data. Puts are the only useful op. + init <- Gen.listOfN(50, genPut(empty)).map { ops => + ReadWriteProgram(ops.toList.sequence, batching = true) + } + state = init.nextState(empty) + // The first program is read/write, it takes a write lock. + prog1 <- genReadWriteProg(state).map(_.copy(batching = true)) + // The second program is read-only, it takes a read lock. + prog2 <- genReadOnlyProg(state) + } yield (init, state, prog1, prog2) + } { case (init, state, prog1, prog2) => + import RocksDBStoreCommands._ + + val sut = newSut(state) + try { + // Connect to the database. + ToggleConnected.run(sut) + // Initialize the database. + init.run(sut) + + // Run them concurrently. They should be serialised. + val (result1, result2) = await { + Task.parMap2(Task(prog1.run(sut)), Task(prog2.run(sut)))((_, _)) + } + + // Need to chain together Read-Write and Read-Only ops to test them as one program. + val liftedRO = + KVStore.instance[RocksDBStore.Namespace].lift(prog2.program) + + // Overall the results should correspond to either prog1 ++ prog2, or prog2 ++ prog1. + val prog12 = ReadWriteProgram((prog1.program, liftedRO).mapN(_ ++ _)) + val prog21 = ReadWriteProgram((liftedRO, prog1.program).mapN(_ ++ _)) + + // One of them should have run first. 
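+ // i.e. the observed results must correspond to one sequential ordering of the two programs: prog1 followed by prog2, or prog2 followed by prog1.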
+ val prop1 = prog1.postCondition(state, Success(result1)) + val prop2 = prog2.postCondition(state, Success(result2)) + // The other should run second, on top of the changes from the first. + val prop12 = prog12.postCondition(state, Success(result1 ++ result2)) + val prop21 = prog21.postCondition(state, Success(result2 ++ result1)) + + (prop1 && prop12) || (prop2 && prop21) + } finally { + destroySut(sut) + } + } +} + +object RocksDBStoreCommands extends Commands { + import RocksDBStore.Namespace + + // The in-memory implementation is our reference execution model. + object InMemoryKVS extends KVStoreState[Namespace] + + // Some structured data to be stored in the database. + case class TestRecord(id: ByteVector, name: String, value: Int) + + // Symbolic state of the test. + case class Model( + // Support opening/closing the database to see if it can read back the files it has created. + isConnected: Boolean, + namespaces: IndexedSeq[Namespace], + store: InMemoryKVS.Store, + deleted: Map[Namespace, Set[Any]], + // Some collections so we have typed access. + coll0: KVCollection[Namespace, String, Int], + coll1: KVCollection[Namespace, Int, ByteVector], + coll2: KVCollection[Namespace, ByteVector, TestRecord] + ) { + + def storeOf(coll: Coll): Map[Any, Any] = + store.getOrElse(namespaces(coll.idx), Map.empty) + + def nonEmptyColls: List[Coll] = + Colls.filter(c => storeOf(c).nonEmpty) + } + sealed trait Coll { + def idx: Int + } + case object Coll0 extends Coll { def idx = 0 } + case object Coll1 extends Coll { def idx = 1 } + case object Coll2 extends Coll { def idx = 2 } + + val Colls = List(Coll0, Coll1, Coll2) + + case class Allocated[T](value: T, release: Task[Unit]) + + class Database( + val namespaces: Seq[Namespace], + val config: Allocated[RocksDBStore.Config], + var maybeConnection: Option[Allocated[RocksDBStore[Task]]] + ) + + type State = Model + type Sut = Database + + def await[T](task: Task[T]): T = { + import monix.execution.Scheduler.Implicits.global + task.runSyncUnsafe(timeout = 10.seconds) + } + + /** Run one database at any time. */ + @nowarn // Traversable deprecated in 2.13 + override def canCreateNewSut( + newState: State, + initSuts: Traversable[State], + runningSuts: Traversable[Sut] + ): Boolean = + initSuts.isEmpty && runningSuts.isEmpty + + /** Start with an empty database. */ + override def initialPreCondition(state: State): Boolean = + state.store.isEmpty && !state.isConnected + + /** Create a new empty database. */ + override def newSut(state: State): Sut = { + val res = for { + path <- Resource.make(Task { + Files.createTempDirectory("testdb") + }) { path => + Task { + if (Files.exists(path)) Files.delete(path) + } + } + + config = RocksDBStore.Config.default(path) + + _ <- Resource.make(Task.unit) { _ => + RocksDBStore.destroy[Task](config) + } + } yield config + + await { + res.allocated.map { case (config, release) => + new Database( + state.namespaces, + Allocated(config, release), + maybeConnection = None + ) + } + } + } + + /** Release the database and all resources. */ + override def destroySut(sut: Sut): Unit = + await { + sut.maybeConnection + .fold(Task.unit)(_.release) + .guarantee(sut.config.release) + } + + /** Initialise a fresh model state. */ + override def genInitialState: Gen[State] = + for { + // Generate at least 3 unique namespaces. 
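+ // (one namespace per typed collection in the model below; any extra namespaces are simply left unused)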
+ n <- Gen.choose(3, 10) + ns <- Gen + .listOfN(n, arbitrary[ByteVector].suchThat(_.nonEmpty)) + .map(_.distinct) + .suchThat(_.size >= 3) + namespaces = ns.map(_.toIndexedSeq).toIndexedSeq + } yield Model( + isConnected = false, + namespaces = namespaces, + store = Map.empty, + deleted = Map.empty, + coll0 = new KVCollection[Namespace, String, Int](namespaces(0)), + coll1 = new KVCollection[Namespace, Int, ByteVector](namespaces(1)), + coll2 = new KVCollection[Namespace, ByteVector, TestRecord](namespaces(2)) + ) + + /** Produce a Command based on the current model state. */ + def genCommand(state: State): Gen[Command] = + if (!state.isConnected) Gen.const(ToggleConnected) + else + Gen.frequency( + 10 -> genReadWriteProg(state), + 3 -> genReadOnlyProg(state), + 1 -> Gen.const(ToggleConnected) + ) + + /** Generate a sequence of writes and reads. */ + def genReadWriteProg(state: State): Gen[ReadWriteProgram] = + for { + batching <- arbitrary[Boolean] + n <- Gen.choose(0, 30) + ops <- Gen.listOfN( + n, + Gen.frequency( + 10 -> genPut(state), + 30 -> genPutExisting(state), + 5 -> genDel(state), + 15 -> genDelExisting(state), + 5 -> genGet(state), + 30 -> genGetExisting(state), + 5 -> genGetDeleted(state) + ) + ) + program = ops.toList.sequence + } yield ReadWriteProgram(program, batching) + + /** Generate a read-only operations. */ + def genReadOnlyProg(state: State): Gen[ReadOnlyProgram] = + for { + n <- Gen.choose(0, 10) + ops <- Gen.listOfN( + n, + Gen.frequency( + 1 -> genRead(state), + 4 -> genReadExisting(state) + ) + ) + program = ops.toList.sequence + } yield ReadOnlyProgram(program) + + implicit val arbColl: Arbitrary[Coll] = Arbitrary { + Gen.oneOf(Coll0, Coll1, Coll2) + } + + implicit val arbByteVector: Arbitrary[ByteVector] = Arbitrary { + arbitrary[Array[Byte]].map(ByteVector(_)) + } + + implicit val arbTestRecord: Arbitrary[TestRecord] = Arbitrary { + for { + id <- arbitrary[ByteVector] + name <- Gen.alphaNumStr + value <- arbitrary[Int] + } yield TestRecord(id, name, value) + } + + def genPut(state: State): Gen[KVStore[Namespace, Any]] = + arbitrary[Coll] flatMap { + case Coll0 => + for { + k <- Gen.alphaLowerStr.suchThat(_.nonEmpty) + v <- arbitrary[Int] + } yield state.coll0.put(k, v) + + case Coll1 => + for { + k <- arbitrary[Int] + v <- arbitrary[ByteVector] + } yield state.coll1.put(k, v) + + case Coll2 => + for { + k <- arbitrary[ByteVector].suchThat(_.nonEmpty) + v <- arbitrary[TestRecord] + } yield state.coll2.put(k, v) + } map { + _.map(_.asInstanceOf[Any]) + } + + def genPutExisting(state: State): Gen[KVStore[Namespace, Any]] = + state.nonEmptyColls match { + case Nil => + genPut(state) + + case colls => + for { + c <- Gen.oneOf(colls) + k <- Gen.oneOf(state.storeOf(c).keySet) + op <- c match { + case Coll0 => + arbitrary[Int].map { v => + state.coll0.put(k.asInstanceOf[String], v) + } + case Coll1 => + arbitrary[ByteVector].map { v => + state.coll1.put(k.asInstanceOf[Int], v) + } + case Coll2 => + arbitrary[TestRecord].map { v => + state.coll2.put(k.asInstanceOf[ByteVector], v) + } + } + } yield op.map(_.asInstanceOf[Any]) + } + + def genDel(state: State): Gen[KVStore[Namespace, Any]] = + arbitrary[Coll] flatMap { + case Coll0 => + arbitrary[String].map(state.coll0.delete) + case Coll1 => + arbitrary[Int].map(state.coll1.delete) + case Coll2 => + arbitrary[ByteVector].map(state.coll2.delete) + } map { + _.map(_.asInstanceOf[Any]) + } + + def genDelExisting(state: State): Gen[KVStore[Namespace, Any]] = + state.nonEmptyColls match { + case Nil => + genGet(state) + + 
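+ // Nothing can be deleted yet, so fall back to generating a read.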
case colls => + for { + c <- Gen.oneOf(colls) + k <- Gen.oneOf(state.storeOf(c).keySet) + op = c match { + case Coll0 => + state.coll0.delete(k.asInstanceOf[String]) + case Coll1 => + state.coll1.delete(k.asInstanceOf[Int]) + case Coll2 => + state.coll2.delete(k.asInstanceOf[ByteVector]) + } + } yield op.map(_.asInstanceOf[Any]) + } + + def genGet(state: State): Gen[KVStore[Namespace, Any]] = + arbitrary[Coll] flatMap { + case Coll0 => + arbitrary[String].map(state.coll0.get) + case Coll1 => + arbitrary[Int].map(state.coll1.get) + case Coll2 => + arbitrary[ByteVector].map(state.coll2.get) + } map { + _.map(_.asInstanceOf[Any]) + } + + def genGetExisting(state: State): Gen[KVStore[Namespace, Any]] = + state.nonEmptyColls match { + case Nil => + genGet(state) + + case colls => + for { + c <- Gen.oneOf(colls) + k <- Gen.oneOf(state.storeOf(c).keySet) + op = c match { + case Coll0 => + state.coll0.get(k.asInstanceOf[String]) + case Coll1 => + state.coll1.get(k.asInstanceOf[Int]) + case Coll2 => + state.coll2.get(k.asInstanceOf[ByteVector]) + } + } yield op.map(_.asInstanceOf[Any]) + } + + def genGetDeleted(state: State): Gen[KVStore[Namespace, Any]] = { + val hasDeletes = + Colls + .map { c => + c -> state.namespaces(c.idx) + } + .filter { case (_, n) => + state.deleted.getOrElse(n, Set.empty).nonEmpty + } + + hasDeletes match { + case Nil => + genGet(state) + + case deletes => + for { + cn <- Gen.oneOf(deletes) + (c, n) = cn + k <- Gen.oneOf(state.deleted(n)) + op = c match { + case Coll0 => + state.coll0.get(k.asInstanceOf[String]) + case Coll1 => + state.coll1.get(k.asInstanceOf[Int]) + case Coll2 => + state.coll2.get(k.asInstanceOf[ByteVector]) + } + } yield op.map(_.asInstanceOf[Any]) + } + } + + def genRead(state: State): Gen[KVStoreRead[Namespace, Any]] = + arbitrary[Coll] flatMap { + case Coll0 => + arbitrary[String].map(state.coll0.read) + case Coll1 => + arbitrary[Int].map(state.coll1.read) + case Coll2 => + arbitrary[ByteVector].map(state.coll2.read) + } map { + _.map(_.asInstanceOf[Any]) + } + + def genReadExisting(state: State): Gen[KVStoreRead[Namespace, Any]] = + state.nonEmptyColls match { + case Nil => + genRead(state) + + case colls => + for { + c <- Gen.oneOf(colls) + k <- Gen.oneOf(state.storeOf(c).keySet) + op = c match { + case Coll0 => + state.coll0.read(k.asInstanceOf[String]) + case Coll1 => + state.coll1.read(k.asInstanceOf[Int]) + case Coll2 => + state.coll2.read(k.asInstanceOf[ByteVector]) + } + } yield op.map(_.asInstanceOf[Any]) + } + + /** Open or close the database. */ + case object ToggleConnected extends UnitCommand { + def run(sut: Sut) = { + sut.maybeConnection match { + case Some(connection) => + await(connection.release) + sut.maybeConnection = None + + case None => + val connection = await { + RocksDBStore[Task](sut.config.value, sut.namespaces).allocated + .map { case (db, release) => + Allocated(db, release) + } + } + sut.maybeConnection = Some(connection) + } + } + + def preCondition(state: State) = true + def nextState(state: State) = state.copy( + isConnected = !state.isConnected + ) + def postCondition(state: State, succeeded: Boolean) = succeeded + } + + case class ReadWriteProgram( + program: KVStore[Namespace, List[Any]], + batching: Boolean = false + ) extends Command { + // Collect all results from a batch of execution steps. 
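+ // The post-condition replays the same program on the in-memory model and expects an identical list of results.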
+ type Result = List[Any] + + def run(sut: Sut): Result = { + val db = sut.maybeConnection + .getOrElse(sys.error("The database is not connected.")) + .value + + await { + if (batching) { + db.runWithBatching(program) + } else { + db.runWithoutBatching(program) + } + } + } + + def preCondition(state: State): Boolean = state.isConnected + + def nextState(state: State): State = { + val nextStore = InMemoryKVS.compile(program).runS(state.store).value + + // Leave only what's still deleted. Add what's been deleted now. + val nextDeleted = state.deleted.map { case (n, ks) => + val existing = nextStore.getOrElse(n, Map.empty).keySet + n -> ks.filterNot(existing) + } ++ state.store.map { case (n, kvs) => + val existing = nextStore.getOrElse(n, Map.empty).keySet + n -> (kvs.keySet -- existing) + } + + state.copy( + store = nextStore, + deleted = nextDeleted + ) + } + + def postCondition(state: Model, result: Try[Result]): Prop = { + val expected = InMemoryKVS.compile(program).runA(state.store).value + result == Success(expected) + } + } + + case class ReadOnlyProgram( + program: KVStoreRead[Namespace, List[Any]] + ) extends Command { + // Collect all results from a batch of execution steps. + type Result = List[Any] + + def run(sut: Sut): Result = { + val db = sut.maybeConnection + .getOrElse(sys.error("The database is not connected.")) + .value + + await { + db.runReadOnly(program) + } + } + + def preCondition(state: State): Boolean = state.isConnected + + def nextState(state: State): State = state + + def postCondition(state: Model, result: Try[Result]): Prop = { + val expected = InMemoryKVS.compile(program).run(state.store) + result == Success(expected) + } + } +} diff --git a/metronome/storage/src/io/iohk/metronome/storage/InMemoryKVStore.scala b/metronome/storage/src/io/iohk/metronome/storage/InMemoryKVStore.scala new file mode 100644 index 00000000..ee17b901 --- /dev/null +++ b/metronome/storage/src/io/iohk/metronome/storage/InMemoryKVStore.scala @@ -0,0 +1,24 @@ +package io.iohk.metronome.storage + +import cats.implicits._ +import cats.effect.Sync +import cats.effect.concurrent.Ref + +/** Simple in-memory key-value store based on `KVStoreState` and `KVStoreRunner`. */ +object InMemoryKVStore { + def apply[F[_]: Sync, N]: F[KVStoreRunner[F, N]] = + Ref.of[F, KVStoreState[N]#Store](Map.empty).map(apply(_)) + + def apply[F[_]: Sync, N]( + storeRef: Ref[F, KVStoreState[N]#Store] + ): KVStoreRunner[F, N] = + new KVStoreState[N] with KVStoreRunner[F, N] { + def runReadOnly[A](query: KVStoreRead[N, A]): F[A] = + storeRef.get.map(compile(query).run) + + def runReadWrite[A](query: KVStore[N, A]): F[A] = + storeRef.modify { store => + compile(query).run(store).value + } + } +} diff --git a/metronome/storage/src/io/iohk/metronome/storage/KVCollection.scala b/metronome/storage/src/io/iohk/metronome/storage/KVCollection.scala new file mode 100644 index 00000000..927ff55a --- /dev/null +++ b/metronome/storage/src/io/iohk/metronome/storage/KVCollection.scala @@ -0,0 +1,41 @@ +package io.iohk.metronome.storage + +import scodec.Codec + +/** Storage for a specific type of data, e.g. blocks, in a given namespace. + * + * We should be able to string together KVStore operations across multiple + * collections and execute them in one batch. + */ +class KVCollection[N, K: Codec, V: Codec](namespace: N) { + + private implicit val kvsRW = KVStore.instance[N] + private implicit val kvsRO = KVStoreRead.instance[N] + + /** Get a value by key, if it exists, for a read-only operation. 
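+ *
+ * A hypothetical usage sketch (the namespace, key and value types are made up,
+ * and scodec codecs for them are assumed to be in scope):
+ * {{{
+ * val ages = new KVCollection[String, String, Int]("ages")
+ * val query: KVStoreRead[String, Option[Int]] = ages.read("alice")
+ * }}}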
*/ + def read(key: K): KVStoreRead[N, Option[V]] = + KVStoreRead[N].read(namespace, key) + + /** Put a value under a key. */ + def put(key: K, value: V): KVStore[N, Unit] = + KVStore[N].put(namespace, key, value) + + /** Get a value by key, if it exists, for potentially doing + * updates based on its value, i.e. the result can be composed + * with `put` and `delete`. + */ + def get(key: K): KVStore[N, Option[V]] = + KVStore[N].get(namespace, key) + + /** Delete a value by key. */ + def delete(key: K): KVStore[N, Unit] = + KVStore[N].delete(namespace, key) + + /** Update a key by getting the value and applying a function on it, if the value exists. */ + def update(key: K)(f: V => V): KVStore[N, Unit] = + KVStore[N].update(namespace, key)(f) + + /** Insert, update or delete a value, depending on whether it exists. */ + def alter(key: K)(f: Option[V] => Option[V]): KVStore[N, Unit] = + KVStore[N].alter(namespace, key)(f) +} diff --git a/metronome/storage/src/io/iohk/metronome/storage/KVRingBuffer.scala b/metronome/storage/src/io/iohk/metronome/storage/KVRingBuffer.scala new file mode 100644 index 00000000..30ed83ab --- /dev/null +++ b/metronome/storage/src/io/iohk/metronome/storage/KVRingBuffer.scala @@ -0,0 +1,119 @@ +package io.iohk.metronome.storage + +import cats.implicits._ +import scodec.{Decoder, Encoder, Codec} + +/** Storing the last N items inserted into a collection. + * + * This component is currently tested through `LedgerStorage`. + */ +class KVRingBuffer[N, K, V]( + coll: KVCollection[N, K, V], + metaNamespace: N, + maxHistorySize: Int +)(implicit codecK: Codec[K]) { + require(maxHistorySize > 0, "Has to store at least one item in the buffer.") + + import KVRingBuffer._ + import scodec.codecs.implicits.implicitIntCodec + + private implicit val kvn = KVStore.instance[N] + + private implicit val metaKeyEncoder: Encoder[MetaKey[_]] = { + import scodec.codecs._ + import scodec.codecs.implicits._ + + val bucketIndexCodec = provide(BucketIndex) + val bucketCodec: Codec[Bucket[_]] = Codec.deriveLabelledGeneric + val keyRefCountCodec: Codec[KeyRefCount[K]] = Codec.deriveLabelledGeneric + + discriminated[MetaKey[_]] + .by(uint2) + .typecase(0, bucketIndexCodec) + .typecase(1, bucketCodec) + .typecase(2, keyRefCountCodec) + .asEncoder + } + + private def getMetaData[V: Decoder](key: MetaKey[V]) = + KVStore[N].get[MetaKey[V], V](metaNamespace, key) + + private def putMetaData[V: Encoder](key: MetaKey[V], value: V) = + KVStore[N].put(metaNamespace, key, value) + + private def setRefCount(key: K, count: Int) = + if (count > 0) + putMetaData[Int](KeyRefCount(key), count) + else + KVStore[N].delete(metaNamespace, KeyRefCount(key)) + + private def getRefCount(key: K) = + getMetaData[Int](KeyRefCount(key)).map(_ getOrElse 0) + + /** Return the index of the next bucket to write the data into. */ + private def nextIndex(maybeIndex: Option[Int]): Int = + maybeIndex.fold(0)(index => (index + 1) % maxHistorySize) + + private def add(key: K, value: V) = + getRefCount(key).flatMap { cnt => + if (cnt == 0) + setRefCount(key, 1) >> coll.put(key, value) + else + setRefCount(key, cnt + 1) + } + + private def maybeRemove(key: K) = + getRefCount(key).flatMap { cnt => + if (cnt > 1) + setRefCount(key, cnt - 1).as(none[K]) + else + setRefCount(key, 0) >> coll.delete(key).as(key.some) + } + + /** Save a new item and remove the oldest one, if we reached + * the maximum history size. 
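+ * (For example, with `maxHistorySize = 2`, putting keys A, B and then C evicts A and returns it, provided no other bucket still references A.)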
+ * + * Returns the key which has been evicted, unless it's still + * referenced by something or the history hasn't reached maximum + * size yet. + */ + def put(key: K, value: V): KVStore[N, Option[K]] = { + for { + index <- getMetaData(BucketIndex).map(nextIndex) + maybeOldestKey <- getMetaData(Bucket[K](index)) + maybeRemoved <- maybeOldestKey match { + case Some(oldestKey) if oldestKey == key => + KVStore[N].pure(none[K]) + + case Some(oldestKey) => + add(key, value) >> maybeRemove(oldestKey) + + case None => + add(key, value).as(none[K]) + } + _ <- putMetaData(Bucket(index), key) + _ <- putMetaData(BucketIndex, index) + } yield maybeRemoved + } + + /** Retrieve an item by hash, if we still have it. */ + def get(key: K): KVStoreRead[N, Option[V]] = + coll.read(key) +} + +object KVRingBuffer { + + /** Keys for different pieces of meta-data stored under a single namespace. */ + sealed trait MetaKey[+V] + + /** Key under which the last written index of the ring buffer is stored. */ + case object BucketIndex extends MetaKey[Int] + + /** Contents of a ring buffer bucket by index. */ + case class Bucket[V](index: Int) extends MetaKey[V] { + assert(index >= 0) + } + + /** Number of buckets currently pointing at a key. */ + case class KeyRefCount[K](key: K) extends MetaKey[Int] +} diff --git a/metronome/storage/src/io/iohk/metronome/storage/KVStore.scala b/metronome/storage/src/io/iohk/metronome/storage/KVStore.scala new file mode 100644 index 00000000..9a61383b --- /dev/null +++ b/metronome/storage/src/io/iohk/metronome/storage/KVStore.scala @@ -0,0 +1,88 @@ +package io.iohk.metronome.storage + +import cats.{~>} +import cats.free.Free +import cats.free.Free.liftF +import scodec.{Encoder, Decoder, Codec} + +/** Helper methods to read/write a key-value store. */ +object KVStore { + + def unit[N]: KVStore[N, Unit] = + pure(()) + + def pure[N, A](a: A): KVStore[N, A] = + Free.pure(a) + + def instance[N]: Ops[N] = new Ops[N] {} + + def apply[N: Ops] = implicitly[Ops[N]] + + /** Scope all operations under the `N` type, which can be more convenient, + * e.g. `KVStore[String].pure(1)` instead of `KVStore.pure[String, Int](1)` + */ + trait Ops[N] { + import KVStoreOp._ + + type KVNamespacedOp[A] = ({ type L[A] = KVStoreOp[N, A] })#L[A] + type KVNamespacedReadOp[A] = ({ type L[A] = KVStoreReadOp[N, A] })#L[A] + + def unit: KVStore[N, Unit] = KVStore.unit[N] + + def pure[A](a: A) = KVStore.pure[N, A](a) + + /** Insert or replace a value under a key. */ + def put[K: Encoder, V: Encoder]( + namespace: N, + key: K, + value: V + ): KVStore[N, Unit] = + liftF[KVNamespacedOp, Unit]( + Put[N, K, V](namespace, key, value) + ) + + /** Get a value under a key, if it exists. */ + def get[K: Encoder, V: Decoder]( + namespace: N, + key: K + ): KVStore[N, Option[V]] = + liftF[KVNamespacedOp, Option[V]]( + Get[N, K, V](namespace, key) + ) + + /** Delete a value under a key. */ + def delete[K: Encoder](namespace: N, key: K): KVStore[N, Unit] = + liftF[KVNamespacedOp, Unit]( + Delete[N, K](namespace, key) + ) + + /** Apply a function on a value, if it exists. */ + def update[K: Encoder, V: Codec](namespace: N, key: K)( + f: V => V + ): KVStore[N, Unit] = + alter[K, V](namespace, key)(_ map f) + + /** Insert, update or delete a value, depending on whether it exists. 
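+ *
+ * A hypothetical sketch of an upsert-style counter (the namespace and key are made
+ * up, and scodec codecs for String and Int are assumed to be in scope):
+ * {{{
+ * val kvs = KVStore.instance[String]
+ * val bump: KVStore[String, Unit] =
+ *   kvs.alter[String, Int]("counters", "hits")(n => Some(n.getOrElse(0) + 1))
+ * }}}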
*/ + def alter[K: Encoder, V: Codec](namespace: N, key: K)( + f: Option[V] => Option[V] + ): KVStore[N, Unit] = + get[K, V](namespace, key).flatMap { current => + (current, f(current)) match { + case ((None, None)) => unit + case ((Some(existing), Some(value))) if existing == value => unit + case (_, Some(value)) => put(namespace, key, value) + case (Some(_), None) => delete(namespace, key) + } + } + + /** Lift a read-only operation into a read-write one, so that we can chain them together. */ + def lift[A](read: KVStoreRead[N, A]): KVStore[N, A] = + read.mapK(liftCompiler) + + private val liftCompiler: KVNamespacedReadOp ~> KVNamespacedOp = + new (KVNamespacedReadOp ~> KVNamespacedOp) { + def apply[A](fa: KVNamespacedReadOp[A]): KVNamespacedOp[A] = + fa + } + } +} diff --git a/metronome/storage/src/io/iohk/metronome/storage/KVStoreOp.scala b/metronome/storage/src/io/iohk/metronome/storage/KVStoreOp.scala new file mode 100644 index 00000000..1e929092 --- /dev/null +++ b/metronome/storage/src/io/iohk/metronome/storage/KVStoreOp.scala @@ -0,0 +1,35 @@ +package io.iohk.metronome.storage + +import scodec.{Encoder, Decoder} + +/** Representing key-value storage operations as a Free Monad, + * so that we can pick an execution strategy that best fits + * the database technology at hand: + * - execute multiple writes atomically by batching + * - execute all reads and writes in a transaction + * + * The key-value store is expected to store binary data, + * so a scodec.Codec is required for all operations to + * serialize the keys and the values. + * + * https://typelevel.org/cats/datatypes/freemonad.html + */ +sealed trait KVStoreOp[N, A] +sealed trait KVStoreReadOp[N, A] extends KVStoreOp[N, A] +sealed trait KVStoreWriteOp[N, A] extends KVStoreOp[N, A] + +object KVStoreOp { + case class Put[N, K, V](namespace: N, key: K, value: V)(implicit + val keyEncoder: Encoder[K], + val valueEncoder: Encoder[V] + ) extends KVStoreWriteOp[N, Unit] + + case class Get[N, K, V](namespace: N, key: K)(implicit + val keyEncoder: Encoder[K], + val valueDecoder: Decoder[V] + ) extends KVStoreReadOp[N, Option[V]] + + case class Delete[N, K](namespace: N, key: K)(implicit + val keyEncoder: Encoder[K] + ) extends KVStoreWriteOp[N, Unit] +} diff --git a/metronome/storage/src/io/iohk/metronome/storage/KVStoreRead.scala b/metronome/storage/src/io/iohk/metronome/storage/KVStoreRead.scala new file mode 100644 index 00000000..000db508 --- /dev/null +++ b/metronome/storage/src/io/iohk/metronome/storage/KVStoreRead.scala @@ -0,0 +1,40 @@ +package io.iohk.metronome.storage + +import cats.free.Free +import cats.free.Free.liftF +import scodec.{Encoder, Decoder} + +/** Helper methods to compose operations that strictly only do reads, no writes. + * + * Basically the same as `KVStore` without `put` and `delete`. 
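+ *
+ * A hypothetical sketch of composing two reads into one query (the namespace, keys
+ * and value type are made up, and scodec codecs are assumed to be in scope):
+ * {{{
+ * val kvs = KVStoreRead.instance[String]
+ * val query: KVStoreRead[String, (Option[Int], Option[Int])] =
+ *   for {
+ *     a <- kvs.read[String, Int]("scores", "alice")
+ *     b <- kvs.read[String, Int]("scores", "bob")
+ *   } yield (a, b)
+ * }}}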
+ */ +object KVStoreRead { + + def unit[N]: KVStoreRead[N, Unit] = + pure(()) + + def pure[N, A](a: A): KVStoreRead[N, A] = + Free.pure(a) + + def instance[N]: Ops[N] = new Ops[N] {} + + def apply[N: Ops] = implicitly[Ops[N]] + + trait Ops[N] { + import KVStoreOp._ + + type KVNamespacedOp[A] = ({ type L[A] = KVStoreReadOp[N, A] })#L[A] + + def unit: KVStoreRead[N, Unit] = KVStoreRead.unit[N] + + def pure[A](a: A) = KVStoreRead.pure[N, A](a) + + def read[K: Encoder, V: Decoder]( + namespace: N, + key: K + ): KVStoreRead[N, Option[V]] = + liftF[KVNamespacedOp, Option[V]]( + Get[N, K, V](namespace, key) + ) + } +} diff --git a/metronome/storage/src/io/iohk/metronome/storage/KVStoreRunner.scala b/metronome/storage/src/io/iohk/metronome/storage/KVStoreRunner.scala new file mode 100644 index 00000000..7372062b --- /dev/null +++ b/metronome/storage/src/io/iohk/metronome/storage/KVStoreRunner.scala @@ -0,0 +1,7 @@ +package io.iohk.metronome.storage + +/** Convenience interface to turn KVStore queries into effects. */ +trait KVStoreRunner[F[_], N] { + def runReadOnly[A](query: KVStoreRead[N, A]): F[A] + def runReadWrite[A](query: KVStore[N, A]): F[A] +} diff --git a/metronome/storage/src/io/iohk/metronome/storage/KVStoreState.scala b/metronome/storage/src/io/iohk/metronome/storage/KVStoreState.scala new file mode 100644 index 00000000..67cee1d9 --- /dev/null +++ b/metronome/storage/src/io/iohk/metronome/storage/KVStoreState.scala @@ -0,0 +1,83 @@ +package io.iohk.metronome.storage + +import cats.{~>} +import cats.data.{State, Reader} +import io.iohk.metronome.storage.KVStoreOp.{Put, Get, Delete} + +/** A pure implementation of the Free interpreter using the State monad. + * + * It uses a specific namespace type, which is common to all collections. + */ +class KVStoreState[N] { + + // Ignoring the Codec for the in-memory use case. + type Store = Map[N, Map[Any, Any]] + // Type aliases to support the `~>` transformation with types that + // only have 1 generic type argument `A`. 
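The read-only algebra composes the same way. Below is a hypothetical query that combines two `read`s applicatively through the `KVCollection` wrapper used elsewhere in this patch; a `KVStoreRunner` implementation can then execute it against a read-only view of the store, for example a snapshot or a read transaction.

```scala
import cats.implicits._
import scodec.codecs.implicits._
import io.iohk.metronome.storage.{KVCollection, KVStoreRead}

object ReadOnlySketch {
  type Namespace = String

  // Hypothetical collection of account balances.
  val balances = new KVCollection[Namespace, String, Int]("balances")

  // A query that can never write; `mapN` comes from cats.
  def combinedBalance(a: String, b: String): KVStoreRead[Namespace, Int] =
    (balances.read(a), balances.read(b)).mapN { (ba, bb) =>
      ba.getOrElse(0) + bb.getOrElse(0)
    }
}
```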
+ type KVNamespacedState[A] = State[Store, A] + type KVNamespacedOp[A] = ({ type L[A] = KVStoreOp[N, A] })#L[A] + + type KVNamespacedReader[A] = Reader[Store, A] + type KVNamespacedReadOp[A] = ({ type L[A] = KVStoreReadOp[N, A] })#L[A] + + private val stateCompiler: KVNamespacedOp ~> KVNamespacedState = + new (KVNamespacedOp ~> KVNamespacedState) { + def apply[A](fa: KVNamespacedOp[A]): KVNamespacedState[A] = + fa match { + case Put(n, k, v) => + State.modify { nkvs => + val kvs = nkvs.getOrElse(n, Map.empty) + nkvs.updated(n, kvs.updated(k, v)) + } + + case Get(n, k) => + State.inspect { nkvs => + for { + kvs <- nkvs.get(n) + v <- kvs.get(k) + // NOTE: This should be fine as long as we access it through + // `KVCollection` which works with 1 kind of value; + // otherwise we could change the effect to allow errors: + // `State[Store, Either[Throwable, A]]` + + // The following cast would work but it's not required: + // .asInstanceOf[A] + } yield v + } + + case Delete(n, k) => + State.modify { nkvs => + val kvs = nkvs.getOrElse(n, Map.empty) - k + if (kvs.isEmpty) nkvs - n else nkvs.updated(n, kvs) + } + } + } + + private val readerCompiler: KVNamespacedReadOp ~> KVNamespacedReader = + new (KVNamespacedReadOp ~> KVNamespacedReader) { + def apply[A](fa: KVNamespacedReadOp[A]): KVNamespacedReader[A] = + fa match { + case Get(n, k) => + Reader { nkvs => + for { + kvs <- nkvs.get(n) + v <- kvs.get(k) + } yield v + } + } + } + + /** Compile a KVStore program to a State monad, which can be executed like: + * + * `new KvStoreState[String].compile(program).run(Map.empty).value` + */ + def compile[A](program: KVStore[N, A]): KVNamespacedState[A] = + program.foldMap(stateCompiler) + + /** Compile a KVStore program to a Reader monad, which can be executed like: + * + * `new KvStoreState[String].compile(program).run(Map.empty)` + */ + def compile[A](program: KVStoreRead[N, A]): KVNamespacedReader[A] = + program.foldMap(readerCompiler) +} diff --git a/metronome/storage/src/io/iohk/metronome/storage/KVTree.scala b/metronome/storage/src/io/iohk/metronome/storage/KVTree.scala new file mode 100644 index 00000000..a3a9d91b --- /dev/null +++ b/metronome/storage/src/io/iohk/metronome/storage/KVTree.scala @@ -0,0 +1,293 @@ +package io.iohk.metronome.storage + +import cats.implicits._ +import scala.collection.immutable.Queue + +/** Storage for nodes that maintains parent-child relationships as well, + * to facilitate tree traversal and pruning. + * + * It is assumed that the application maintains some pointers into the tree + * where it can start traversing from, e.g. the last Commit Quorum Certificate + * would point at a block hash which would serve as the entry point. + * + * This component is currently tested through `BlockStorage`. + */ +class KVTree[N, K, V]( + nodeColl: KVCollection[N, K, V], + nodeMetaColl: KVCollection[N, K, KVTree.NodeMeta[K]], + parentToChildrenColl: KVCollection[N, K, Set[K]] +)(implicit ev: KVTree.Node[K, V]) { + import KVTree.NodeMeta + + private implicit val kvn = KVStore.instance[N] + private implicit val kvrn = KVStoreRead.instance[N] + + /** Insert a node into the store, and if the parent still exists, + * then add this node to its children. 
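A `KVStoreRunner` (from the previous file) can be backed directly by these pure compilers. The sketch below keeps the whole store in a `Ref` and is deliberately assumption-heavy: it uses cats-effect 3's `IO` and `Ref`, which are not part of this patch, and a production runner would instead translate the operations to a database, batching writes as the `KVStoreOp` scaladoc suggests.

```scala
import cats.effect.{IO, Ref}
import io.iohk.metronome.storage.{KVStore, KVStoreRead, KVStoreRunner, KVStoreState}

/** Toy in-memory runner reusing the pure compilers; assumes cats-effect 3. */
class InMemoryRunner[N](ref: Ref[IO, Map[N, Map[Any, Any]]])
    extends KVStoreState[N]
    with KVStoreRunner[IO, N] {

  def runReadOnly[A](query: KVStoreRead[N, A]): IO[A] =
    ref.get.map(store => compile(query).run(store))

  def runReadWrite[A](query: KVStore[N, A]): IO[A] =
    ref.modify(store => compile(query).run(store).value)
}

object InMemoryRunner {
  def apply[N]: IO[InMemoryRunner[N]] =
    Ref.of[IO, Map[N, Map[Any, Any]]](Map.empty).map(new InMemoryRunner(_))
}
```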
+ */ + def put(value: V): KVStore[N, Unit] = { + val nodeKey = ev.key(value) + val meta = + NodeMeta(ev.parentKey(value), ev.height(value)) + + nodeColl.put(nodeKey, value) >> + nodeMetaColl.put(nodeKey, meta) >> + parentToChildrenColl.alter(meta.parentKey) { maybeChildren => + maybeChildren orElse Set.empty.some map (_ + nodeKey) + } + + } + + /** Retrieve a node by key, if it exists. */ + def get(key: K): KVStoreRead[N, Option[V]] = + nodeColl.read(key) + + /** Check whether a node is present in the tree. */ + def contains(key: K): KVStoreRead[N, Boolean] = + nodeMetaColl.read(key).map(_.isDefined) + + /** Check how many children the node has in the tree. */ + private def childCount(key: K): KVStoreRead[N, Int] = + parentToChildrenColl.read(key).map(_.fold(0)(_.size)) + + /** Check whether the parent of the node is present in the tree. */ + private def hasParent(key: K): KVStoreRead[N, Boolean] = + nodeMetaColl.read(key).flatMap { + case None => KVStoreRead[N].pure(false) + case Some(meta) => contains(meta.parentKey) + } + + private def getParentKey( + key: K + ): KVStoreRead[N, Option[K]] = + nodeMetaColl.read(key).map(_.map(_.parentKey)) + + /** Check whether it's safe to delete a node. + * + * A node is safe to delete if doing so doesn't break up the tree + * into a forest, in which case we may have nodes we cannot reach + * by traversal, leaking space. + * + * This is true if the node has no children, + * or it has no parent and at most one child. + */ + private def canDelete(key: K): KVStoreRead[N, Boolean] = + (hasParent(key), childCount(key)).mapN { + case (_, 0) => true + case (false, 1) => true + case _ => false + } + + /** Delete a node by hash, if doing so wouldn't break the tree; + * otherwise do nothing. + * + * Return `true` if node has been deleted, `false` if not. + * + * If this is not efficent enough, then move the deletion traversal + * logic into the this class so it can make sure all the invariants + * are maintained, e.g. collect all hashes that can be safely deleted + * and then do so without checks. + */ + def delete(key: K): KVStore[N, Boolean] = + canDelete(key).lift.flatMap { ok => + deleteUnsafe(key).whenA(ok).as(ok) + } + + /** Delete a node and remove it from any parent-to-child mapping, + * without any checking for the tree structure invariants. + */ + def deleteUnsafe(key: K): KVStore[N, Unit] = { + def deleteIfEmpty(maybeChildren: Option[Set[K]]) = + maybeChildren.filter(_.nonEmpty) + + getParentKey(key).lift.flatMap { + case None => + KVStore[N].unit + case Some(parentKey) => + parentToChildrenColl.alter(parentKey) { maybeChildren => + deleteIfEmpty(maybeChildren.map(_ - key)) + } + } >> + nodeColl.delete(key) >> + nodeMetaColl.delete(key) >> + // Keep the association from existing children, until they last one is deleted. + parentToChildrenColl.alter(key)(deleteIfEmpty) + } + + /** Get the ancestor chain of a node from the root, including the node itself. + * + * If the node is not in the tree, the result will be empty, + * otherwise `head` will be the root of the node tree, + * and `last` will be the node itself. + */ + def getPathFromRoot(key: K): KVStoreRead[N, List[K]] = { + def loop( + key: K, + acc: List[K] + ): KVStoreRead[N, List[K]] = { + getParentKey(key).flatMap { + case None => + // This node doesn't exist in the tree, so our ancestry is whatever we collected so far. + KVStoreRead[N].pure(acc) + + case Some(parentKey) => + // So at least `key` exists in the tree. 
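To show how the `Node` evidence (defined in the companion object further down) ties into this class, here is a hypothetical sketch with integer block ids. The collection names, the `Block` type and the `Set` codec are all made up for illustration, and the case-class codecs are assumed to derive from `scodec.codecs.implicits._` as in the tests.

```scala
import scodec.Codec
import scodec.codecs.{int32, listOfN}
import scodec.codecs.implicits._
import io.iohk.metronome.storage._

object KVTreeSketch {
  type Namespace = String

  // Toy node type; in Metronome this role is played by blocks.
  case class Block(id: Int, parentId: Int, height: Long)

  implicit val blockNode: KVTree.Node[Int, Block] =
    new KVTree.Node[Int, Block] {
      def key(b: Block): Int       = b.id
      def parentKey(b: Block): Int = b.parentId
      def height(b: Block): Long   = b.height
    }

  // Stand-in codec for the children sets; the codec the real storage
  // module uses is not shown in this patch.
  implicit def setCodec[A](implicit ev: Codec[A]): Codec[Set[A]] =
    listOfN(int32, ev).xmap(_.toSet, _.toList)

  val tree = new KVTree[Namespace, Int, Block](
    new KVCollection[Namespace, Int, Block]("blocks"),
    new KVCollection[Namespace, Int, KVTree.NodeMeta[Int]]("block-meta"),
    new KVCollection[Namespace, Int, Set[Int]]("block-children")
  )

  // Insert a two-block chain and read back the path from the root.
  val program: KVStore[Namespace, List[Int]] =
    for {
      _    <- tree.put(Block(1, 0, 1))
      _    <- tree.put(Block(2, 1, 2))
      path <- tree.getPathFromRoot(2).lift
    } yield path

  // With the in-memory interpreter the expected path is List(1, 2).
  val (_, path) =
    new KVStoreState[Namespace].compile(program).run(Map.empty).value
}
```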
+ loop(parentKey, key :: acc) + } + } + loop(key, Nil) + } + + /** Get the ancestor chain between two hashes in the chain, if there is one. + * + * If either of the nodes are not in the tree, or there's no path between them, + * return an empty list. This can happen if we have already pruned away the ancestry as well. + */ + def getPathFromAncestor( + ancestorKey: K, + descendantKey: K + ): KVStoreRead[N, List[K]] = { + def loop( + key: K, + acc: List[K], + maxDistance: Long + ): KVStoreRead[N, List[K]] = { + if (key == ancestorKey) { + KVStoreRead[N].pure(key :: acc) + } else if (maxDistance == 0) { + KVStoreRead[N].pure(Nil) + } else { + nodeMetaColl.read(key).flatMap { + case None => + KVStoreRead[N].pure(Nil) + case Some(meta) => + loop(meta.parentKey, key :: acc, maxDistance - 1) + } + } + } + + ( + nodeMetaColl.read(ancestorKey), + nodeMetaColl.read(descendantKey) + ).mapN((_, _)) + .flatMap { + case (Some(ameta), Some(dmeta)) => + loop( + descendantKey, + Nil, + maxDistance = dmeta.height - ameta.height + ) + case _ => KVStoreRead[N].pure(Nil) + } + } + + /** Collect all descendants of a node, including the node itself. + * + * The result will start with the nodes furthest away, + * so it should be safe to delete them in the same order; + * `last` will be the node itself. + * + * The `skip` parameter can be used to avoid traversing + * branches that we want to keep during deletion. + */ + def getDescendants( + key: K, + skip: Set[K] = Set.empty + ): KVStoreRead[N, List[K]] = { + // BFS traversal. + def loop( + queue: Queue[K], + acc: List[K] + ): KVStoreRead[N, List[K]] = { + queue.dequeueOption match { + case None => + KVStoreRead[N].pure(acc) + + case Some((key, queue)) if skip(key) => + loop(queue, acc) + + case Some((key, queue)) => + parentToChildrenColl.read(key).flatMap { + case None => + // Since we're not inserting an empty child set, + // we can't tell here if the node exists or not. + loop(queue, key :: acc) + case Some(children) => + loop(queue ++ children, key :: acc) + } + } + } + + loop(Queue(key), Nil).flatMap { + case result @ List(`key`) => + result.filterA(contains) + case result => + KVStoreRead[N].pure(result) + } + } + + /** Delete all nodes which are not descendants of a given node, making it the new root. + * + * Return the list of deleted node keys. + */ + def pruneNonDescendants(key: K): KVStore[N, List[K]] = + getPathFromRoot(key).lift.flatMap { + case Nil => + KVStore[N].pure(Nil) + + case path @ (rootHash :: _) => + // The safe order to delete nodes would be to go down the main chain + // from the root, delete each non-mainchain child, then the parent, + // then descend on the main chain until we hit `key`. + + // A similar effect can be achieved by collecting all descendants + // of the root, then deleting everything that isn't on the main chain, + // from the children towards the root, and finally the main chain itself, + // going from the root towards the children. + val isMainChain = path.toSet + + for { + deleteables <- getDescendants(rootHash, skip = Set(key)).lift + _ <- deleteables.filterNot(isMainChain).traverse(deleteUnsafe(_)) + _ <- path.init.traverse(deleteUnsafe(_)) + } yield deleteables + } + + /** Remove all nodes in a tree, given by a key that's in the tree, + * except perhaps a new root (and its descendants) we want to keep. + * + * This is used to delete an old tree when starting a new that's most likely + * not connected to it, and would otherwise result in a forest. 
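Continuing the hypothetical `KVTreeSketch` above, pruning everything that does not descend from a newly finalized block might look as follows; the expected result assumes the two-block chain from that sketch plus the one fork inserted here.

```scala
import io.iohk.metronome.storage._
import KVTreeSketch.{tree, Block, Namespace}

object PruningSketch {
  // Add a competing child of block 1, then make block 2 the new root.
  // Everything not on or below block 2 is removed, including the old root,
  // so with the chain from KVTreeSketch the removed keys are List(3, 1).
  val program: KVStore[Namespace, List[Int]] =
    for {
      _       <- KVTreeSketch.program
      _       <- tree.put(Block(3, 1, 2))
      removed <- tree.pruneNonDescendants(2)
    } yield removed
}
```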
+ */ + def purgeTree( + key: K, + keep: Option[K] + ): KVStore[N, List[K]] = + getPathFromRoot(key).lift.flatMap { + case Nil => + KVStore[N].pure(Nil) + + case rootHash :: _ => + for { + ds <- getDescendants(rootHash, skip = keep.toSet).lift + // Going from the leaves towards the root. + _ <- ds.reverse.traverse(deleteUnsafe(_)) + } yield ds + } +} + +object KVTree { + + /** Type class for the node-like values stored in the tree. */ + trait Node[K, V] { + def key(value: V): K + def parentKey(value: V): K + def height(value: V): Long + } + + /** Properties about the nodes that we frequently need for traversal. */ + case class NodeMeta[K]( + parentKey: K, + height: Long + ) +} diff --git a/metronome/storage/src/io/iohk/metronome/storage/package.scala b/metronome/storage/src/io/iohk/metronome/storage/package.scala new file mode 100644 index 00000000..c340dc80 --- /dev/null +++ b/metronome/storage/src/io/iohk/metronome/storage/package.scala @@ -0,0 +1,18 @@ +package io.iohk.metronome + +import cats.free.Free + +package object storage { + + /** Read/Write operations over a key-value store. */ + type KVStore[N, A] = Free[({ type L[A] = KVStoreOp[N, A] })#L, A] + + /** Read-only operations over a key-value store. */ + type KVStoreRead[N, A] = Free[({ type L[A] = KVStoreReadOp[N, A] })#L, A] + + /** Extension method to lift a read-only operation to read-write. */ + implicit class KVStoreReadOps[N, A](val read: KVStoreRead[N, A]) + extends AnyVal { + def lift: KVStore[N, A] = KVStore.instance[N].lift(read) + } +} diff --git a/metronome/storage/test/src/io/iohk/metronome/storage/KVStoreStateSpec.scala b/metronome/storage/test/src/io/iohk/metronome/storage/KVStoreStateSpec.scala new file mode 100644 index 00000000..4778e4d2 --- /dev/null +++ b/metronome/storage/test/src/io/iohk/metronome/storage/KVStoreStateSpec.scala @@ -0,0 +1,45 @@ +package io.iohk.metronome.storage + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import scodec.codecs.implicits._ + +class KVStoreStateSpec extends AnyFlatSpec with Matchers { + import KVStoreStateSpec._ + + behavior of "KVStoreState" + + it should "compose multiple collections" in { + type Namespace = String + // Two independent collections with different types of keys and values. 
+ val collA = new KVCollection[Namespace, Int, RecordA](namespace = "a") + val collB = new KVCollection[Namespace, String, RecordB](namespace = "b") + + val program: KVStore[Namespace, Option[RecordA]] = for { + _ <- collA.put(1, RecordA("one")) + _ <- collB.put("two", RecordB(2)) + b <- collB.get("three") + _ <- collB.put("three", RecordB(3)) + _ <- collB.delete("two") + _ <- + if (b.isEmpty) collA.put(4, RecordA("four")) + else KVStore.unit[Namespace] + a <- collA.read(1).lift + } yield a + + val compiler = new KVStoreState[Namespace] + + val (store, maybeA) = compiler.compile(program).run(Map.empty).value + + maybeA shouldBe Some(RecordA("one")) + store shouldBe Map( + "a" -> Map(1 -> RecordA("one"), 4 -> RecordA("four")), + "b" -> Map("three" -> RecordB(3)) + ) + } +} + +object KVStoreStateSpec { + case class RecordA(a: String) + case class RecordB(b: Int) +} diff --git a/metronome/tracing/src/io/iohk/metronome/tracer/Tracer.scala b/metronome/tracing/src/io/iohk/metronome/tracer/Tracer.scala new file mode 100644 index 00000000..09bb2ace --- /dev/null +++ b/metronome/tracing/src/io/iohk/metronome/tracer/Tracer.scala @@ -0,0 +1,147 @@ +package io.iohk.metronome.tracer + +import language.higherKinds +import cats.{Applicative, Contravariant, FlatMap, Id, Monad, Monoid, Show, ~>} + +/** Contravariant tracer. + * + * Ported from https://github.com/input-output-hk/contra-tracer/blob/master/src/Control/Tracer.hs + */ +trait Tracer[F[_], -A] { + def apply(a: => A): F[Unit] +} + +object Tracer { + + def instance[F[_], A](f: (=> A) => F[Unit]): Tracer[F, A] = + new Tracer[F, A] { + override def apply(a: => A): F[Unit] = f(a) + } + + def const[F[_], A](f: F[Unit]): Tracer[F, A] = + instance(_ => f) + + /** If you know: + * - how to enrich type A that is traced + * - how to squeeze B's to create A's (possibly enrich B with extra stuff, or forget some details) + * then you have Tracer for B + * + * Example + * ``` + * import cats.syntax.contravariant._ + * + * val atracer: Tracer[F, A] = ??? + * val btracer: Tracer[F, B] = atracer.contramap[B](b => b.toA) + * ```. 
+ */ + implicit def contraTracer[F[_]]: Contravariant[Tracer[F, *]] = + new Contravariant[Tracer[F, *]] { + override def contramap[A, B](fa: Tracer[F, A])(f: B => A): Tracer[F, B] = + new Tracer[F, B] { + override def apply(a: => B): F[Unit] = fa(f(a)) + } + } + + def noOpTracer[M[_], A](implicit MA: Applicative[M]): Tracer[M, A] = + new Tracer[M, A] { + override def apply(a: => A): M[Unit] = MA.pure(()) + } + + implicit def monoidTracer[F[_], S](implicit + MA: Applicative[F] + ): Monoid[Tracer[F, S]] = + new Monoid[Tracer[F, S]] { + + /** Run sequentially two tracers */ + override def combine(a1: Tracer[F, S], a2: Tracer[F, S]): Tracer[F, S] = + s => MA.productR(a1(s))(a2(s)) + + override def empty: Tracer[F, S] = noOpTracer + } + + /** Trace value a using tracer tracer */ + def traceWith[F[_], A](tracer: Tracer[F, A], a: A): F[Unit] = tracer(a) + + /** contravariant Kleisli composition: + * if you can: + * - produce effect M[B] from A + * - trace B's + * then you can trace A's + */ + def contramapM[F[_], A, B](f: A => F[B], tracer: Tracer[F, B])(implicit + MM: FlatMap[F] + ): Tracer[F, A] = { + new Tracer[F, A] { + override def apply(a: => A): F[Unit] = + MM.flatMap(f(a))(tracer(_)) + } + } + + /** change the effect F to G using natural transformation nat */ + def natTracer[F[_], G[_], A]( + nat: F ~> G, + tracer: Tracer[F, A] + ): Tracer[G, A] = + a => nat(tracer(a)) + + /** filter out values to trace if they do not pass predicate p */ + def condTracing[F[_], A](p: A => Boolean, tr: Tracer[F, A])(implicit + FM: Applicative[F] + ): Tracer[F, A] = { + new Tracer[F, A] { + override def apply(a: => A): F[Unit] = + if (p(a)) tr(a) + else FM.pure(()) + } + } + + /** filter out values that was send to trace using side effecting predicate */ + def condTracingM[F[_], A](p: F[A => Boolean], tr: Tracer[F, A])(implicit + FM: Monad[F] + ): Tracer[F, A] = + a => + FM.flatMap(p) { pa => + if (pa(a)) tr(a) + else FM.pure(()) + } + + def showTracing[F[_], A]( + tracer: Tracer[F, String] + )(implicit S: Show[A], C: Contravariant[Tracer[F, *]]): Tracer[F, A] = + C.contramap(tracer)(S.show) + + def traceAll[A, B](f: B => List[A], t: Tracer[Id, A]): Tracer[Id, B] = + new Tracer[Id, B] { + override def apply(event: => B): Id[Unit] = f(event).foreach(t(_)) + } +} + +object TracerSyntax { + + implicit class TracerOps[F[_], A](val tracer: Tracer[F, A]) extends AnyVal { + + /** Trace value a using tracer tracer */ + def trace(a: A): F[Unit] = tracer(a) + + /** contravariant Kleisli composition: + * if you can: + * - produce effect M[B] from A + * - trace B's + * then you can trace A's + */ + def >=>[B](f: B => F[A])(implicit MM: FlatMap[F]): Tracer[F, B] = + Tracer.contramapM(f, tracer) + + def nat[G[_]](nat: F ~> G): Tracer[G, A] = + Tracer.natTracer(nat, tracer) + + def filter(p: A => Boolean)(implicit FM: Applicative[F]): Tracer[F, A] = + Tracer.condTracing[F, A](p, tracer) + + def filterNot(p: A => Boolean)(implicit FM: Applicative[F]): Tracer[F, A] = + filter(a => !p(a)) + + def filterM(p: F[A => Boolean])(implicit FM: Monad[F]): Tracer[F, A] = + Tracer.condTracingM(p, tracer) + } +} diff --git a/versionFile/version b/versionFile/version new file mode 100644 index 00000000..4ecb6644 --- /dev/null +++ b/versionFile/version @@ -0,0 +1 @@ +0.1.0-SNAPSHOT \ No newline at end of file
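Finally, a hypothetical sketch of how the contravariant `Tracer` above is meant to be used: start from a tracer of strings, `contramap` it to a domain event type, filter it, and combine tracers with the `Monoid` instance. The event type and the `println` sink are made up for illustration, and `Id` is used as the effect to keep the example minimal.

```scala
import cats.Id
import cats.implicits._
import io.iohk.metronome.tracer.Tracer
import io.iohk.metronome.tracer.TracerSyntax._

object TracerSketch {
  sealed trait ConsensusEvent
  case class BlockExecuted(height: Long)   extends ConsensusEvent
  case class ViewTimeout(viewNumber: Long) extends ConsensusEvent

  // A sink that prints whatever it is given.
  val stdout: Tracer[Id, String] = new Tracer[Id, String] {
    def apply(s: => String): Id[Unit] = println(s)
  }

  // Contravariantly derive an event tracer: say how to render an event.
  val events: Tracer[Id, ConsensusEvent] =
    stdout.contramap[ConsensusEvent] {
      case BlockExecuted(h)  => s"executed block at height $h"
      case ViewTimeout(view) => s"view $view timed out"
    }

  // Keep only timeouts, and run both tracers in sequence via the Monoid.
  val timeoutsOnly: Tracer[Id, ConsensusEvent] = events.filter {
    case _: ViewTimeout => true
    case _              => false
  }
  val combined: Tracer[Id, ConsensusEvent] = events |+| timeoutsOnly

  def main(args: Array[String]): Unit = {
    combined.trace(BlockExecuted(1)) // printed once
    combined.trace(ViewTimeout(2))   // printed twice
  }
}
```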