diff --git a/README.md b/README.md index db4cd66..17c8e78 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ Official GitHub repo: https://github.com/scalawithcats/scala-with-cats 1. [Introduction](src/main/scala/ch01) 2. [Monoids and Semigroups](src/main/scala/ch02) 3. [Functors](src/main/scala/ch03) +4. [Monads](src/main/scala/ch04) ## Running tests diff --git a/src/main/scala/ch04/Eval.scala b/src/main/scala/ch04/Eval.scala new file mode 100644 index 0000000..42d7407 --- /dev/null +++ b/src/main/scala/ch04/Eval.scala @@ -0,0 +1,16 @@ +package ch04 + +import cats.Eval as CatsEval + +/* +4.6.5 Exercise: Safer Folding using Eval +The naive implementation of foldRight below is not stack safe. Make it so using Eval. + */ +object Eval: + def foldRight[A, B](as: List[A], acc: B)(fn: (A, B) => B): B = + def foldR(xs: List[A]): CatsEval[B] = + xs match + case head :: tail => CatsEval.defer(foldR(tail).map(fn(head, _))) + case Nil => CatsEval.now(acc) + + foldR(as).value diff --git a/src/main/scala/ch04/Monad.scala b/src/main/scala/ch04/Monad.scala new file mode 100644 index 0000000..e7660a4 --- /dev/null +++ b/src/main/scala/ch04/Monad.scala @@ -0,0 +1,28 @@ +package ch04 + +trait Monad[F[_]]: + def pure[A](a: A): F[A] + + def flatMap[A, B](value: F[A])(f: A => F[B]): F[B] + + /* + 4.1.2 Exercise: Getting Func-y + Every monad is also a functor. We can define map in the same way + for every monad using the existing methods, flatMap and pure. + Try defining map yourself now. + */ + def map[A, B](value: F[A])(f: A => B): F[B] = + flatMap(value)(a => pure(f(a))) + + object MonadInstances: + type Id[A] = A + + /* + 4.3.1 Exercise: Monadic Secret Identities + Implement pure, map, and flatMap for Id! + */ + given idMonad: Monad[Id] with + def pure[A](a: A): Id[A] = a + + def flatMap[A, B](value: Id[A])(f: A => Id[B]): Id[B] = + f(value) diff --git a/src/main/scala/ch04/MonadError.scala b/src/main/scala/ch04/MonadError.scala new file mode 100644 index 0000000..55192dd --- /dev/null +++ b/src/main/scala/ch04/MonadError.scala @@ -0,0 +1,19 @@ +package ch04 + +import cats.{Monad, MonadError as CatsMonadError} +import cats.syntax.applicativeError.catsSyntaxApplicativeErrorId + +/* +4.5.4 Exercise: Abstracting +Implement a method validateAdult with the following signature + +def validateAdult[F[_]](age: Int)(implicit me: MonadError[F, Throwable]): F[Int] + +When passed an age greater than or equal to 18 it should return that value as a success. +Otherwise it should return a error represented as an IllegalArgumentException. + */ +object MonadError: + def validateAdult[F[_]](age: Int)(implicit me: CatsMonadError[F, Throwable]): F[Int] = + if age >= 18 + then Monad[F].pure(age) + else new IllegalArgumentException("Age must be greater than or equal to 18").raiseError[F, Int] diff --git a/src/main/scala/ch04/Reader.scala b/src/main/scala/ch04/Reader.scala new file mode 100644 index 0000000..c15b319 --- /dev/null +++ b/src/main/scala/ch04/Reader.scala @@ -0,0 +1,40 @@ +package ch04 + +import cats.data.Reader as CatsReader +import cats.syntax.applicative.catsSyntaxApplicativeId + +/* +4.8.3 Exercise: Hacking on Readers + +The classic use of Readers is to build programs that accept a configuration as a parameter. +Let's ground this with a complete example of a simple login system. +Our configuration will consist of two databases: a list of valid users and a list of their passwords. + +Start by creating a type alias DbReader for a Reader that consumes a Db as input. + +Now create methods that generate DbReaders to look up the username for an Int user ID, +and look up the password for a String username. + +Finally create a checkLogin method to check the password for a given user ID. + */ +object Reader: + final case class Db( + usernames: Map[Int, String], + passwords: Map[String, String] + ) + + type DbReader[A] = CatsReader[Db, A] + + def findUsername(userId: Int): DbReader[Option[String]] = + CatsReader(_.usernames.get(userId)) + + def checkPassword(username: String, password: String): DbReader[Boolean] = + CatsReader(_.passwords.get(username).contains(password)) + + def checkLogin(userId: Int, password: String): DbReader[Boolean] = + for { + username <- findUsername(userId) + passwordOk <- username + .map(checkPassword(_, password)) + .getOrElse(false.pure[DbReader]) + } yield passwordOk diff --git a/src/main/scala/ch04/State.scala b/src/main/scala/ch04/State.scala new file mode 100644 index 0000000..135b9d0 --- /dev/null +++ b/src/main/scala/ch04/State.scala @@ -0,0 +1,58 @@ +package ch04 + +import cats.data.State as CatsState +import cats.syntax.applicative.catsSyntaxApplicativeId +import cats.syntax.apply.catsSyntaxApplyOps + +/* +4.9.3 Exercise: Post-Order Calculator +Let's write an interpreter for post-order expressions. +We can parse each symbol into a State instance representing +a transformation on the stack and an intermediate result. + +Start by writing a function evalOne that parses a single symbol into an instance of State. +If the stack is in the wrong configuration, it's OK to throw an exception. + */ +object State: + type Stack = List[Int] + type CalcState[A] = CatsState[Stack, A] + + def eval(sym: String, s: Stack): Stack = + s match + case x :: y :: s1 => + sym match + case "+" => x + y :: s1 + case "-" => y - x :: s1 + case "*" => x * y :: s1 + case "/" if x != 0 => y / x :: s1 + case "/" => sys.error("divide by zero") + case _ => sys.error("bad expression") + + def evalOne(sym: String): CalcState[Int] = + for + s <- CatsState.get[Stack] + s1 = sym match + case x if x.forall(Character.isDigit) => x.toInt :: s + case x => eval(sym, s) + + _ <- CatsState.set[Stack](s1) + yield s1.head + + /* + Generalise this example by writing an evalAll method that computes the result of a List[String]. + Use evalOne to process each symbol, and thread the resulting State monads together using flatMap. + */ + def evalAll(input: List[String]): CalcState[Int] = + input.foldLeft(0.pure[CalcState]) { (s, x) => + // We discard the value, but must use the previous + // state for the next computation. + // Simply invoking evalOne will create a new state. + s *> evalOne(x) + } + + /* + Complete the exercise by implementing an evalInput function that splits an input String into symbols, + calls evalAll, and runs the result with an initial stack. + */ + def evalInput(input: String): Int = + evalAll(input.split(" ").toList).runA(Nil).value diff --git a/src/main/scala/ch04/Tree.scala b/src/main/scala/ch04/Tree.scala new file mode 100644 index 0000000..d909907 --- /dev/null +++ b/src/main/scala/ch04/Tree.scala @@ -0,0 +1,43 @@ +package ch04 + +import cats.Monad as CatsMonad + +/* +4.10.1 Exercise: Branching out Further with Monads + +Let's write a Monad for the Tree data type given below. + +Verify that the code works on instances of Branch and Leaf, +and that the Monad provides Functor-like behaviour for free. + +Also verify that having a Monad in scope allows us to use for comprehensions, +despite the fact that we haven’t directly implemented flatMap or map on Tree. + +Don't feel you have to make tailRecM tail-recursive. Doing so is quite difficult. + */ +sealed trait Tree[+A] + +final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A] + +final case class Leaf[A](value: A) extends Tree[A] + +def branch[A](left: Tree[A], right: Tree[A]): Tree[A] = + Branch(left, right) + +def leaf[A](value: A): Tree[A] = + Leaf(value) + +given CatsMonad[Tree] with + override def pure[A](x: A): Tree[A] = + Leaf(x) + + override def flatMap[A, B](t: Tree[A])(f: A => Tree[B]): Tree[B] = + t match + case Leaf(x) => f(x) + case Branch(l, r) => Branch(flatMap(l)(f), flatMap(r)(f)) + + // Not stack-safe! + override def tailRecM[A, B](a: A)(f: A => Tree[Either[A, B]]): Tree[B] = + flatMap(f(a)): + case Left(value) => tailRecM(value)(f) + case Right(value) => Leaf(value) diff --git a/src/main/scala/ch04/Writer.scala b/src/main/scala/ch04/Writer.scala new file mode 100644 index 0000000..bc2cf11 --- /dev/null +++ b/src/main/scala/ch04/Writer.scala @@ -0,0 +1,27 @@ +package ch04 + +import cats.data.Writer as CatsWriter +import cats.syntax.applicative.catsSyntaxApplicativeId +import cats.syntax.writer.catsSyntaxWriterId + +/* +4.7.3 Exercise: Show Your Working + +Rewrite factorial so it captures the log messages in a Writer. +Demonstrate that this allows us to reliably separate the logs for concurrent computations. + */ +object Writer: + def slowly[A](body: => A): A = + try body + finally Thread.sleep(100) + + type Logged[A] = CatsWriter[Vector[String], A] + + def factorial(n: Int): Logged[Int] = + for + ans <- + if (n == 0) + then 1.pure[Logged] + else slowly(factorial(n - 1).map(_ * n)) + _ <- Vector(s"fact $n $ans").tell + yield ans diff --git a/src/main/scala/ch04/ch04.worksheet.sc b/src/main/scala/ch04/ch04.worksheet.sc new file mode 100644 index 0000000..ab37482 --- /dev/null +++ b/src/main/scala/ch04/ch04.worksheet.sc @@ -0,0 +1,146 @@ +import cats.Eval +import cats.data.{Reader, Writer, State} +import cats.syntax.applicative.catsSyntaxApplicativeId +import cats.syntax.writer.catsSyntaxWriterId + +// ---------------------------------------------------------------------------- +// Eval +// ---------------------------------------------------------------------------- + +// call-by-value which is eager and memoized +val now = Eval.now(math.random + 1000) +// call-by-name which is lazy and not memoized +val always = Eval.always(math.random + 3000) +// call-by-need which is lazy and memoized +val later = Eval.later(math.random + 2000) + +now.value +always.value +later.value + +val greeting = Eval + .always{ println("Step 1"); "Hello" } + .map{ str => println("Step 2"); s"$str world" } + +greeting.value +// Step 1 +// Step 2 +// res16: String = "Hello world" + +val ans = for { + a <- Eval.now{ println("Calculating A"); 40 } + b <- Eval.always{ println("Calculating B"); 2 } +} yield { + println("Adding A and B") + a + b +} + +ans.value // first access +// Calculating B +// Adding A and B +// res17: Int = 42 // first access +ans.value // second access +// Calculating B +// Adding A and B +// res18: Int = 42 + +val saying = Eval + .always{ println("Step 1"); "The cat" } + .map{ str => println("Step 2"); s"$str sat on" } + .memoize + .map{ str => println("Step 3"); s"$str the mat" } + +saying.value // first access +// Step 1 +// Step 2 +// Step 3 +// res19: String = "The cat sat on the mat" // first access +saying.value // second access +// Step 3 +// res20: String = "The cat sat on the mat" + +// stack-safe +def factorial(n: BigInt): Eval[BigInt] = + if(n == 1) { + Eval.now(n) + } else { + Eval.defer(factorial(n - 1).map(_ * n)) + } + +factorial(50000).value + +// stack-safe foldRight +ch04.Eval.foldRight((1 to 100000).toList, 0L)(_ + _) + +// ---------------------------------------------------------------------------- +// Writer +// ---------------------------------------------------------------------------- +type Logged[A] = Writer[Vector[String], A] + +123.pure[Logged] + +Vector("msg1", "msg2", "msg3").tell + +val b = 123.writer(Vector("msg1", "msg2", "msg3")) + +val writer1 = for { + a <- 10.pure[Logged] + _ <- Vector("a", "b", "c").tell + b <- 32.writer(Vector("x", "y", "z")) +} yield a + b + +writer1.run + +val writer2 = writer1.mapWritten(_.map(_.toUpperCase)) + +writer2.run + +val writer3 = writer1.bimap( + log => log.map(_.toUpperCase), + res => res * 100 +) + +writer3.run + +val writer5 = writer1.reset + +writer5.run + +// ---------------------------------------------------------------------------- +// Reader +// ---------------------------------------------------------------------------- +final case class Cat(name: String, favoriteFood: String) + +val catName: Reader[Cat, String] = + Reader(cat => cat.name) + +val greetKitty: Reader[Cat, String] = + catName.map(name => s"Hello ${name}") + +val feedKitty: Reader[Cat, String] = + Reader(cat => s"Have a nice bowl of ${cat.favoriteFood}") + +val greetAndFeed: Reader[Cat, String] = + for { + greet <- greetKitty + feed <- feedKitty + } yield s"$greet. $feed." + +greetAndFeed(Cat("Garfield", "lasagne")) + +// ---------------------------------------------------------------------------- +// State +// ---------------------------------------------------------------------------- + +val a = State[Int, String] { state => + (state, s"The state is $state") +} + +// Get the state and the result +val (state, result) = a.run(10).value + +// Get the state, ignore the result +val justTheState = a.runS(10).value + +// Get the result, ignore the state +val justTheResult = a.runA(10).value diff --git a/src/test/scala/ch01/eq/CatSpec.scala b/src/test/scala/ch01/eq/CatSpec.scala index d2dc0ca..3e48b1b 100644 --- a/src/test/scala/ch01/eq/CatSpec.scala +++ b/src/test/scala/ch01/eq/CatSpec.scala @@ -5,16 +5,15 @@ import org.scalatest.matchers.should.Matchers.shouldBe import cats.syntax.eq.catsSyntaxEq class CatSpec extends AnyFunSpec: - describe("Cat"): - it("should use Eq for equality"): - val cat1 = Cat("Garfield", 38, "orange and black") - val cat2 = Cat("Heathcliff", 32, "orange and black") + it("Cat should use Eq for equality"): + val cat1 = Cat("Garfield", 38, "orange and black") + val cat2 = Cat("Heathcliff", 32, "orange and black") - cat1 === cat2 `shouldBe` false - cat1 =!= cat2 `shouldBe` true + cat1 === cat2 `shouldBe` false + cat1 =!= cat2 `shouldBe` true - val optionCat1 = Option(cat1) - val optionCat2 = Option.empty[Cat] + val optionCat1 = Option(cat1) + val optionCat2 = Option.empty[Cat] - optionCat1 === optionCat2 `shouldBe` false - optionCat1 =!= optionCat2 `shouldBe` true + optionCat1 === optionCat2 `shouldBe` false + optionCat1 =!= optionCat2 `shouldBe` true diff --git a/src/test/scala/ch01/printable/CatSpec.scala b/src/test/scala/ch01/printable/CatSpec.scala index 053744b..468ed41 100644 --- a/src/test/scala/ch01/printable/CatSpec.scala +++ b/src/test/scala/ch01/printable/CatSpec.scala @@ -5,7 +5,6 @@ import org.scalatest.matchers.should.Matchers.shouldBe import PrintableSyntax.* class CatSpec extends AnyFunSpec: - describe("Cat"): - it("should use Printable to print the cat"): - Cat("Garfield", 41, "ginger and black").format `shouldBe` - "Garfield is a 41 year-old ginger and black cat." + it("Cat should use Printable to print the cat"): + Cat("Garfield", 41, "ginger and black").format `shouldBe` + "Garfield is a 41 year-old ginger and black cat." diff --git a/src/test/scala/ch01/show/CatSpec.scala b/src/test/scala/ch01/show/CatSpec.scala index bdcf16f..0bab0cd 100644 --- a/src/test/scala/ch01/show/CatSpec.scala +++ b/src/test/scala/ch01/show/CatSpec.scala @@ -5,7 +5,6 @@ import org.scalatest.matchers.should.Matchers.shouldBe import cats.syntax.show.toShow class CatSpec extends AnyFunSpec: - describe("Cat"): - it("should use Show to print the cat"): - Cat("Garfield", 41, "ginger and black").show `shouldBe` - "Garfield is a 41 year-old ginger and black cat." + it("Cat should use Show to print the cat"): + Cat("Garfield", 41, "ginger and black").show `shouldBe` + "Garfield is a 41 year-old ginger and black cat." diff --git a/src/test/scala/ch04/MonadErrorSpec.scala b/src/test/scala/ch04/MonadErrorSpec.scala new file mode 100644 index 0000000..6ea3f9c --- /dev/null +++ b/src/test/scala/ch04/MonadErrorSpec.scala @@ -0,0 +1,15 @@ +package ch04 +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.should.Matchers.shouldBe +import scala.util.Try +import scala.util.Success + +class MonadErrorSpec extends AnyFunSpec: + it("should check if adult"): + val actual = MonadError.validateAdult[Try](18) + actual `shouldBe` Success(18) + + val actual2 = MonadError.validateAdult[Try](8) + // No easy way to get the exception out + // except for pattern matching. + actual2.isFailure `shouldBe` true diff --git a/src/test/scala/ch04/ReaderSpec.scala b/src/test/scala/ch04/ReaderSpec.scala new file mode 100644 index 0000000..a619e0c --- /dev/null +++ b/src/test/scala/ch04/ReaderSpec.scala @@ -0,0 +1,22 @@ +package ch04 +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.should.Matchers.shouldBe + +class ReaderSpec extends AnyFunSpec: + it("DbReader should be able to check password"): + val users = Map( + 1 -> "dade", + 2 -> "kate", + 3 -> "margo" + ) + + val passwords = Map( + "dade" -> "zerocool", + "kate" -> "acidburn", + "margo" -> "secret" + ) + + val db = Reader.Db(users, passwords) + + Reader.checkLogin(1, "zerocool").run(db) `shouldBe` true + Reader.checkLogin(4, "davinci").run(db) `shouldBe` false diff --git a/src/test/scala/ch04/StateSpec.scala b/src/test/scala/ch04/StateSpec.scala new file mode 100644 index 0000000..7d99044 --- /dev/null +++ b/src/test/scala/ch04/StateSpec.scala @@ -0,0 +1,8 @@ +package ch04 + +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.should.Matchers.shouldBe + +class StateSpec extends AnyFunSpec: + it("evalInput should be able to evaluate a post-order expression"): + State.evalInput("1 2 + 3 4 + *") `shouldBe` 21 diff --git a/src/test/scala/ch04/TreeSpec.scala b/src/test/scala/ch04/TreeSpec.scala new file mode 100644 index 0000000..d6ff871 --- /dev/null +++ b/src/test/scala/ch04/TreeSpec.scala @@ -0,0 +1,33 @@ +package ch04 + +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.should.Matchers.shouldBe +import cats.syntax.flatMap.toFlatMapOps +import cats.syntax.functor.toFunctorOps + +class TreeSpec extends AnyFunSpec: + it("Tree monad should support flatMap, map, and for-comprehension"): + val actual = branch(leaf(100), leaf(200)) + .flatMap(x => branch(leaf(x - 1), leaf(x + 1))) + val expected = branch( + branch(leaf(99), leaf(101)), + branch(leaf(199), leaf(201)) + ) + actual `shouldBe` expected + + val actual2 = for + a <- branch(leaf(100), leaf(200)) + b <- branch(leaf(a - 10), leaf(a + 10)) + c <- branch(leaf(b - 1), leaf(b + 1)) + yield c + val expected2 = branch( + branch( + branch(leaf(89), leaf(91)), + branch(leaf(109), leaf(111)) + ), + branch( + branch(leaf(189), leaf(191)), + branch(leaf(209), leaf(211)) + ) + ) + actual2 `shouldBe` expected2 diff --git a/src/test/scala/ch04/WriterSpec.scala b/src/test/scala/ch04/WriterSpec.scala new file mode 100644 index 0000000..d8b2985 --- /dev/null +++ b/src/test/scala/ch04/WriterSpec.scala @@ -0,0 +1,21 @@ +package ch04 +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.should.Matchers.shouldBe +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration.* +import scala.concurrent.{Await, Future} + +class WriterSpec extends AnyFunSpec: + it("factorial should be able maintain the order of logging"): + val computations = Future.sequence( + Vector( + Future(Writer.factorial(5)), + Future(Writer.factorial(5)) + ) + ) + val actual = Await.result( + computations.map(_.map(_.written)), + 5.seconds + ) + val logs = Vector("fact 0 1", "fact 1 1", "fact 2 2", "fact 3 6", "fact 4 24", "fact 5 120") + actual `shouldBe` Vector(logs, logs)