Skip to content

Instantly share code, notes, and snippets.

@aoiroaoino
Last active September 16, 2019 06:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aoiroaoino/f017458c29d0b98eeaf239529f04af22 to your computer and use it in GitHub Desktop.
Save aoiroaoino/f017458c29d0b98eeaf239529f04af22 to your computer and use it in GitHub Desktop.
Scala秋祭り2019 発表資料内に登場するソースコードの元

検証環境とか

  • Scala 2.13.0 の REPL
  • ScalaCheck 1.14.0

補足とか

参考資料

(敬称略)

発表当日の質問

スクリーンショット 2019-09-16 15 17 36

Q&A 1

Q: ここでの「継続」と「関数」の違いは何でしょうか?「継続 = その後に計算される計算/処理」だとして、 Scalaにおいては関数が「継続」を表現するのにちょうど良いので、関数を使って「継続渡し」を説明している、という感じでしょうか? A:

Q&A 2

Q: for式で使いたいだけであればモナドにする必要はないと思いますが、モナドにしておく理由はありますか? A:

Q&A 3

Q: 継続を見て「これはモナドかも?」という直感が働くのはどういった点ですか? A:

def isEven(i: Int): Boolean = 1 % 2 == 0
val result = isEven(42)
if (result) {
println(s"$n is even number")
} else {
println(s"$n is odd number")
}
// ===
val result = isEven(42)
if (result) {
println(s"$n is even number") // 実行される
} else {
println(s"$n is odd number") // 実行されない(捨てられる)
}
// ===
def add(i: Int, j: Int): Int = i + j
def mul(i: Int, j: Int): Int = i * j
def show(i: Int): String = "num: " + i.toString
val a = add(1, 2)
val b = add(a, 3)
val s = show(b)
println(s)
// ===
def add(i: Int, j: Int): Int = i + j
def mul(i: Int, j: Int): Int = i * j
def show(i: Int): String = "num: " + i.toString
val a = add(1, 2) // 計算結果
// ↓ その後の計算
val b = mul(a, 3)
val s = show(b)
println(s)
// ===
def add(i: Int, j: Int): Int = i + j
def mul(i: Int, j: Int): Int = i * j
def show(i: Int): String = "num: " + i.toString
val a = add(1, 2)
val b = mul(a, 3) // 計算結果
// その後の計算
val s = show(b)
println(s)
// ===
def add(i: Int, j: Int): Int = i + j
def mul(i: Int, j: Int): Int = i * j
def show(i: Int): String = "num: " + i.toString
val a = add(1, 2)
val b = mul(a, 3)
val s = show(b) // 計算結果
// その後の計算
println(s)
// 通常の計算(直接スタイル)
def add(i: Int, j: Int): Int = i + j
// 継続渡しスタイル
def add[R](i: Int, j: Int)(cont: Int => R): R = {
val a = i + j
cont(a)
}
// 通常の計算(直接スタイル)
def show(i: Int): String = "num: " + i.toString
// 継続渡しスタイル
def show[R](i: Int)(cont: String => R): R = {
val a = "num: " + i.toString
cont(a)
}
// ===
scala> def add[R](i: Int, j: Int)(cont: Int => R): R = {
| val a = i + j
| cont(a)
| }
add: [R](i: Int, j: Int)(cont: Int => R)R
// 1 + 2 を実行。その後の計算で「10倍する」
scala> add(1, 2)(a => a * 10)
res0: Int = 30
// 1 + 2 を実行。その後の計算で「引数を返す」
scala> add(1, 2)(a => a)
res1: Int = 3
// ===
scala> def show[R](i: Int)(cont: String => R): R = {
| val a = "num: " + i.toString
| cont(a)
| }
show: [R](i: Int)(cont: String => R)R
// 1 を文字列に変換し prefix をつける。その後の計算「suffix をつける」
scala> show(1)(a => a + " !!")
res3: String = num: 1 !!
// 1 を文字列に変換し prefix をつける。その後の計算「Byte 配列を得る」
scala> show(1)(a => a.getBytes(java.nio.charset.StandardCharsets.UTF_8))
res4: Array[Byte] = Array(110, 117, 109, 58, 32, 49)
// ===
// 直接スタイル(再掲)
val a = add(1, 2)
val b = mul(a, 3)
val s = show(b)
println(s)
// ===
// 継続渡しスタイル
add(1, 2){ a =>
mul(a, 3){ b =>
show(b){ s =>
println(s)
}
}
}
// ===
// 継続ってどこだっけ?
def add[R](i: Int, j: Int)(cont: Int => R): R = {
val a = i + j
cont(a)
}
def show[R](i: Int)(cont: String => R): R = {
val a = "num: " + i.toString
cont(a)
}
// ===
// メソッド定義を再掲
def add[R](i: Int, j: Int)(cont: Int => R): R = { ??? }
def show[R](i: Int) (cont: String => R): R = { ??? }
// 継続部分を関数に変更
def add[R](i: Int, j: Int): (Int => R) => R = { cont => ??? }
def show[R](i: Int) : (String => R) => R = { cont => ??? }
// (A => R) => R という関数に Cont[R, A] と名前をつける
type Cont[R, A] = (A => R) => R
def add[R](i: Int, j: Int): Cont[R, Int] = { cont => ??? }
def show[R](i: Int) : Cont[R, String] = { cont => ??? }
// ===
// Scala ではしばしばラップするデータ型を定義
final case class Cont[R, A](run: (A => R) => R)
// もちろん、データ型 Cont[R, A] を用いて add, show が定義できる
def add[R](i: Int, j: Int): Cont[R, Int] = Cont { cont => ??? }
def show[R](i: Int) : Cont[R, String] = Cont { cont => ??? }
// ===
// Monad にしたい(for式で合成したい)
object Cont {
def pure[R, A](a: A): Cont[R, A] =
???
}
final case class Cont[R, A](run: (A => R) => R) {
def flatMap[B](f: A => Cont[R, B]): Cont[R, B] =
???
def map[B](f: A => B): Cont[R, B] =
???
}
// ===
// pure の実装
object Cont {
def pure[R, A](a: A): Cont[R, A] =
Cont(ar => ar(a)) // pure の実装
}
final case class Cont[R, A](run: (A => R) => R) {
def flatMap[B](f: A => Cont[R, B]): Cont[R, B] =
???
def map[B](f: A => B): Cont[R, B] =
???
}
// ===
// flatMap の実装 1
object Cont {
def pure[R, A](a: A): Cont[R, A] =
Cont(ar => ar(a))
}
final case class Cont[R, A](run: (A => R) => R) {
def flatMap[B](f: A => Cont[R, B]): Cont[R, B] =
Cont(br => ???) // 計算結果を考える
def map[B](f: A => B): Cont[R, B] =
???
}
// ===
// flatMap の実装 2
object Cont {
def pure[R, A](a: A): Cont[R, A] =
Cont(ar => ar(a))
}
final case class Cont[R, A](run: (A => R) => R) {
def flatMap[B](f: A => Cont[R, B]): Cont[R, B] =
Cont(br => run(a => f(a) /* Cont[R, B] */)) // 自身の継続として、引数の関数fにその結果を与える
def map[B](f: A => B): Cont[R, B] =
???
}
// ===
// flatMap の実装 3
object Cont {
def pure[R, A](a: A): Cont[R, A] =
Cont(ar => ar(a))
}
final case class Cont[R, A](run: (A => R) => R) {
def flatMap[B](f: A => Cont[R, B]): Cont[R, B] =
Cont(br => run(a => f(a).run(br))) // fに結果を与えて得られた継続に外側の継続を渡す
def map[B](f: A => B): Cont[R, B] =
???
}
// ===
// map はモナドのデフォルト実装
object Cont {
def pure[R, A](a: A): Cont[R, A] =
Cont(ar => ar(a))
}
final case class Cont[R, A](run: (A => R) => R) {
def flatMap[B](f: A => Cont[R, B]): Cont[R, B] =
Cont(br => run(a => f(a).run(br)))
def map[B](f: A => B): Cont[R, B] =
flatMap(a => Cont.pure(f(a))) // ひとまずモナドのデフォルト実装
}
// ===
sbt.version=1.2.8
ThisBuild / scalaVersion := "2.13.0"
ThisBuild / version := "0.1.0-SNAPSHOT"
ThisBuild / organization := "com.example"
ThisBuild / organizationName := "example"
lazy val root = (project in file("."))
.settings(
name := "check_cont_monad_law",
libraryDependencies += "org.scalacheck" %% "scalacheck" % "1.14.0" % Test
)
// ===
package example
import org.scalacheck.{Prop, Properties}
final case class Cont[R, A](run: (A => R) => R) {
def map[B](f: A => B): Cont[R, B] = Cont(br => run(f andThen br))
def flatMap[B](f: A => Cont[R, B]): Cont[R, B] =
Cont(br => run(f(_).run(br)))
}
object Cont {
def pure[R, A](a: A): Cont[R, A] = Cont(_(a))
}
class ContMonadSpec extends Properties("Monad[Cont[R, ?]]") {
def inc(i: Int): Cont[Int, Int] = Cont(_(i + 1))
def add_![R](s: String): Cont[String, String] = Cont(_(s + "!"))
def add_?[R](s: String): Cont[String, String] = Cont(_(s + "?"))
property("rightIdentity") = Prop.forAll { i: Int =>
inc(i).flatMap(Cont.pure).run(identity) == inc(i).run(identity)
}
property("leftIdentity") = Prop.forAll { i: Int =>
Cont.pure[Int, Int](i).flatMap(inc).run(identity) == inc(i).run(identity)
}
property("associativity") = Prop.forAll { s: String =>
Cont.pure(s).flatMap(add_!).flatMap(add_?).run(identity) ==
Cont.pure(s).flatMap(a => add_!(a).flatMap(add_?)).run(identity)
}
}
// ===
sbt:check_cont_monad_law> test
[info] + Monad[Cont[R, ?]].rightIdentity: OK, passed 100 tests.
[info] + Monad[Cont[R, ?]].leftIdentity: OK, passed 100 tests.
[info] + Monad[Cont[R, ?]].associativity: OK, passed 100 tests.
[info] Passed: Total 3, Failed 0, Errors 0, Passed 3
[success] Total time: 1 s, completed 2019/09/15 19:49:36
// ===
def add[R](i: Int, j: Int): Cont[R, Int] = Cont(ar => ar(i + j))
def mul[R](i: Int, j: Int): Cont[R, Int] = Cont(ar => ar(i * j))
def show[R](i: Int): Cont[R, String] = Cont(ar => ar(s"num: $i"))
def prog[R]: Cont[R, String] =
for {
a <- add(1, 2)
b <- mul(a, 3)
s <- show(b)
} yield {
s.toUpperCase
}
// ===
scala> prog.run(s => s.toList)
res18: List[Char] = List(N, U, M, :, , 9)
scala> prog.run(s => s.length)
res19: Int = 6
scala> prog.run(s => s)
res20: String = NUM: 9
// ===
def prog[R]: Cont[R, String] =
add(1, 2).flatMap { a =>
mul(a, 3).flatMap { b =>
show(b).map { s =>
s.toUpperCase
}
}
}
// fizzbuzz
def fizzCont(i: Int): Cont[String, Int] = Cont { ar =>
if (i % 3 == 0) {
"Fizz" // 継続(ar)を実行しないので計算が "Fizz" で終了する
} else {
ar(i) // 継続に i を渡し、残りの処理を実行する
}
}
// ===
def fizzCont(i: Int): Cont[String, Int] = Cont(ar => if (i % 3 == 0) "Fizz" else ar(i))
def buzzCont(i: Int): Cont[String, Int] = Cont(ar => if (i % 5 == 0) "Buzz" else ar(i))
def fizzBuzzCont(i: Int): Cont[String, Int] = Cont(ar => if (i % 15 == 0) "FizzBuzz" else ar(i))
def fizzBuzz(i: Int): Cont[String, Int] =
for {
a <- fizzBuzzCont(i)
b <- fizzCont(a)
c <- buzzCont(b)
} yield c
// ===
scala> LazyList.from(1).map(fizzBuzz(_).run(_.toString)).take(15).toList
res54: List[String] = List(1, 2, Fizz, 4, Buzz, Fizz, 7, 8, Fizz, Buzz, 11, Fizz, 13, 14, FizzBuzz)
// ===
import scala.util.Try
// Some だったら継続に値を渡して実行。None だったら継続を破棄して ifNone の結果を返す
def someValueOr[R, A](fa: Option[A])(ifNone: => R): Cont[R, A] =
Cont(ar => fa.fold(ifNone)(ar)) // Cont(fa.fold(ifNone))
// Success だったら継続に値を渡して実行。None だったら継続を破棄して ifFailure の結果を返す
def successValueOr[R, A](fa: Try[A])(ifFailure: Throwable => R): Cont[R, A] =
Cont(ar => fa.fold(ifFailure, ar)) // Cont(fa.fold(ifFailure, _))
// input から指定された key に対応する値を取り出し、数値型に変換する
def parseInt(input: Map[String, String], key: String): Cont[String, Int] =
for {
s <- someValueOr(input.get(key))(s"not found: $key")
i <- successValueOr(Try(s.toInt))(_.toString)
} yield i
// ===
// 例えば下記のような input を仮定する
val input = Map("name" -> "aoiroaoino", "age" -> "17")
parseInt(input, "address").run(i => s"result: $i")
parseInt(input, "name").run(i => s"result: $i")
parseInt(input, "age").run(i => s"result: $i")
// ===
scala> parseInt(input, "address").run(i => s"result: $i")
res99: String = not found: address
scala> parseInt(input, "name").run(i => s"result: $i")
res100: String = java.lang.NumberFormatException: For input string: "aoiroaoino"
scala> parseInt(input, "age").run(i => s"result: $i")
res101: String = result: 17
// ===
def using[A <: AutoCloseable, B](a: => A, n: String)(f: A => B): B =
try f(a) finally { a.close(); println(s"close: $n") }
// Scala 2.13.0 より scala.util.Using が入った
//話を簡略化するため、例外が投げられる resource を使用
def resource[R, A](resource: R)(body: (R) => A)(implicit releasable: Releasable[R]): A
// 継続モナドでラップ
def resource[R, A](a: => A)(implicit releasable: Releasable[A]): Cont[R, A] =
Cont(ar => Using.resource(a)(ar))
// ===
def lines(reader: BufferedReader): Iterator[String] =
Iterator.unfold(())(_ => Option(reader.readLine()).map(_ -> ()))
// Using.Magaer を使わず、しかも for 式で合成できる
val prog: Cont[List[String], List[String]] = for {
a <- resource(new BufferedReader(new FileReader("/tmp/file1.txt")))
b <- resource(new BufferedReader(new FileReader("/tmp/file2.txt")))
c <- resource(new BufferedReader(new FileReader("/tmp/file3.txt")))
} yield {
(lines(a) ++ lines(b) ++ lines(c)).toList
}
scala> prog.run(identity)
res140: List[String] = List(Hello, 1, Hello, 2, Hello, 3)
// ===
// close() 実行時に println する独自 Releasable を定義
def r(n: String): Using.Releasable[AutoCloseable] =
_.close().tap(_ => println(s"call close() for $n"))
// Releasable を明示的に渡して実行
val prog: Cont[List[String], List[String]] = for {
a <- resource(new BufferedReader(new FileReader("/tmp/file1.txt")))(r("f1"))
b <- resource(new BufferedReader(new FileReader("/tmp/file2.txt")))(r("f2"))
c <- resource(new BufferedReader(new FileReader("/tmp/file3.txt")))(r("f3"))
} yield {
(lines(a) ++ lines(b) ++ lines(c)).toList
}
scala> prog.run(identity)
call close() for f3
call close() for f2
call close() for f1
res145: List[String] = List(Hello, 1, Hello, 2, Hello, 3)
// ===
type User = String
type Result = String
def findAll[R](): Cont[Result, List[User]] = Cont { ar =>
try {
ar(List("foo", "bar")) // DB への問い合わせは成功したとする
} catch {
case _: Throwable => "query execution error"
}
}
// ===
// 全て成功する場合
val prog: Cont[Result, List[User]] =
for {
users <- findAll()
names <- Cont.pure(users.map(n => s"[$n]"))
} yield names
scala> prog.run(_.mkString(", "))
res132: Result = [foo], [bar]
// ===
// 継続の深いところで例外が投げられた場合
val prog: Cont[Result, List[User]] =
for {
users <- findAll()
names <- Cont.pure(users.map(n => s"[$n]"))
_ = 1 / 0 // java.lang.ArithmeticException: / by zero
} yield names
scala> prog.run(_.mkString(", "))
res133: Result = query execution error
// ===
// 後続の処理をなんとなく二回実行するやつ
def twice: Cont[Unit, Unit] = Cont { ar => ar(()); ar(()) }
// 絶対に一回しか実行して欲しくない DB への書き込み
def insert(s: String): Cont[Unit, Unit] = Cont { ar =>
println("insert data to database")
ar(())
}
val prog: Cont[Unit, Unit] =
for {
_ <- twice
_ <- insert("some data")
} yield ()
// めでたく二回実行されてしまいました
scala> prog.run(identity)
insert data to database
insert data to database
// M[_] はモナドの型
def pure[A](a: A): M[A]
def flatMap[B](f: A => M[B]): M[B]
// 右単位元
fa.flatMap(a => pure(a)) == fa
// 左単位元
pure(a).flatMap(f) == f(a)
// 結合律
fa.flatMap(f).flatMap(g) == fa.flatMap(a => f(a).flatMap(g))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment