Skip to content

Instantly share code, notes, and snippets.

@gakuzzzz
Last active October 29, 2019 06:19
Show Gist options
  • Star 16 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gakuzzzz/7848606 to your computer and use it in GitHub Desktop.
Save gakuzzzz/7848606 to your computer and use it in GitHub Desktop.
「Javaで継続モナド」をScalaに翻訳/Scala Advent Calendar 2013

「Javaで継続モナド」をScalaに翻訳

この記事はScala Advent Calendar 2013の7日目の記事です。

昨日は @shogogg さんのScala + sbt-android + IntelliJ で快適Androidアプリ開発でした。

明日は @takezoux2 さんのScalaのParserCombinator実践入門+です。

継続モナドを調べていたら、@terazzo さんのJavaで継続モナドという記事が非常に判りやすかったんですが、サンプルコードがJavaのボイラープレートの嵐でちょっと読むのが辛い感じだったのでScalaで翻訳してみました、というのがこの記事です。

なので基本的に解説の中身は @terazzo さんの記事をご参照ください。一応、元記事を引用して、その後にScalaの翻訳コードを載せるという形式で進めていきます。

では早速いきましょう。

継続渡し形式とは

(※前に書いた「内部イテレータを外部イテレータに手動で変換してみる実験 - terazzoの日記」も参照)

良くある関数の使い方は、関数を呼び出して戻り値を受け取り、その戻り値を使って次の処理をおこなうという形である。

例えば数値を二乗するメソッドと、その結果をプリントする処理があるとする。素直に実装すると下のようになる。

   /** xを二乗して戻すメソッド */
   public int dup(int x) {
       return x * x;
   }
...
   // 使う側
   // 10をdup()して結果をprintln()する
   int x = 10;
   int result = dup(x);
   System.out.println("result = " + result);

このコードをScalaに直すと以下になります。この辺は特に説明要らないですね。

    /** xを二乗して戻すメソッド */
    def dup(x: Int): Int = x * x
...
    // 使う側
    // 10をdup()して結果をprintln()する
    val x = 10
    val result = dup(x)
    println("result = " + result)

それに対し、CPS(継続渡し形式)というのは、関数の戻り値を使って「次の処理」をおこなう代わりに、「次の処理」をおこなう関数をあらかじめ引数と一緒に渡し、戻り値を戻す代わりにその戻り値を引数にその「次の処理」を呼んでもらうスタイルである。

   // 計算結果をintで受け取り、それをプリントするクラス
   public static class PrintResultHandler {
       public void printResult(int result) {
           System.out.println("result = " + result);
       }
   }
   // dup()の継続渡し版。結果をreturnする代わりにresultHandlerのprintResult()を呼ぶ */
   public void dupCps(int x, PrintResultHandler resultHandler) {
       int result = x * x;
       resultHandler.printResult(result);
   }
...
   // 使う側
   // return値を受け取って処理する代わりにPrintResultHandlerを渡す
   int x = 10;
   dupCps(x, new PrintResultHandler());

これをベースに、継続モナドの実装を考えていく。

    // 計算結果をintで受け取り、それをプリントするクラス
    class PrintResultHandler {
        def printResult(result: Int): Unit = println("result = " + result)
    }
    // dup()の継続渡し版。結果をreturnする代わりにresultHandlerのprintResult()を呼ぶ */
    def dupCps(x: Int, resultHandler: PrintResultHandler): Unit = {
        val result = x * x
        resultHandler.printResult(result)
    }
...
    // 使う側
    // return値を受け取って処理する代わりにPrintResultHandlerを渡す
    val x = 10
    dupCps(x, new PrintResultHandler())

継続渡し形式のメソッドの関数化

上の例のdupCpsを関数的に書いてみる。

関数を表すために、Google guavaのFunctionクラスか、自分で書くなら下のようなクラスを用意する。

/**
* @param <T> 引数の型
* @param <R> 戻り値の型
*/
public abstract class Function<T, R> {
   public abstract R apply(T value);
}

この Function クラスについては、Scala標準の Function1 を使用するので割愛します。

これを使って、まず、PrintResultHandlerを関数で表現する。

   Function<Integer, Void> printResult = new Function<Integer, Void>() {
       public Void apply(Integer value) {
           System.out.println("result = " + value);
           return null;
       }
   };

次に、dupCps自体を関数として書く。

   public Function<Integer, Void> dupCps = new Function<Integer, Void>() {
       public Void apply(Integer x) {
           int result = x * x;
           return printResult.apply(result); // printResultはどうやって渡す?
       }
   };

printResultdupCps は同時に記載します。

   val printResult: Int => Unit = value => println("result = " + value)

   val dupCps: Int => Unit = x => {
     val result = x * x
     printResult(result); // printResultはどうやって渡す?
   }

printResultをどうやって渡そうか?引数を数値ではなく数値 + printResultを取れるようにしても良いのだけど、「あとからprintResultを渡すと処理が実行される」ように、dupを遅延処理的に書くことで対処してみる。

   public Function<Function<Integer, Void>, Void> makeDupCps(final int x) {
       return new Function<Function<Integer, Void>, Void>() {
           public Void apply(Function<Integer, Void> resultHandler) {
               int result = x * x;
               return resultHandler.apply(result);
           }
       };
   }
...
   // 使う側
   // return値を受け取って処理する代わりにprintResult関数を渡す
   int x = 10;
   makeDupCps(x).apply(printResult);

Scalaに直すと以下になります。

    def makeDupCps(x: Int): (Int => Unit) => Unit = {
      resultHandler: (Int => Unit) => {
        val result = x * x
        resultHandler(result)
      }
    }

...
    // 使う側
    // return値を受け取って処理する代わりにprintResult関数を渡す
    val x = 10
    makeDupCps(x)(printResult)

=> は Haskell と同じく右結合なので、Int => Unit を受け取って Unit を返す関数の型は (Int => Unit) => Unit と明示的に括弧をつける必要があります。

上のint result = x * x;return resultHandler.apply(result);を見る通り、処理をおこなっている部分の記述自体はあくまでCPSのままであるが、「後の処理」は後から引数で渡せるようになった。「後の処理」を付け替えることも出来て便利そう。

   // xを二乗してHTMLで出力
   int x = 10;
   Function<Integer, Void> printResultInHtml = new Function<Integer, Void>() {
       public Void apply(Integer value) {
           System.out.printf(
                   "<html>\n" +
                   "<head><title>Result</title><head>\n" +
                   "<body>\n" +
                   "Result = %d\n" +
                   "</body>\n" +
                   "</html>\n", value);
           return null;
       }
   };
   makeDupCps(x).apply(printResultInHtml);

Multi-Line String Literals と String Interpolation 便利ー。

    // xを二乗してHTMLで出力
    val x = 10
    val printResultInHtml: Int => Unit =
      value => println(s"""
        <html>
        <head><title>Result</title><head>
        <body>
        Result = ${value}
        </body>
        </html>
      """)

    makeDupCps(x)(printResultInHtml)

さらに汎用性を高めて、makeDupCps()の部分を「二乗する」処理以外の処理も受け取れるようにしてみる。

   public Function<Function<Integer, Void>, Void> makeCps(
           final Function<Function<Integer, Void>, Void> run) {

       return new Function<Function<Integer, Void>, Void>() {
           public Void apply(Function<Integer, Void> resultHandler) {
               return run.apply(resultHandler);
           }
       };
   }
   public Function<Function<Integer, Void>, Void> makeDupCps(final int x) {
       Function<Function<Integer, Void>, Void> dup =
           new Function<Function<Integer,Void>, Void>() {
               public Void apply(Function<Integer, Void> resultHandler) {
                   int result = x * x;
                   return resultHandler.apply(result);
               }
            };

       return makeCps(dup);
   }

このmakeCps()の引数の「run」の部分、つまり「引数に関数をとり、自分の仕事の最後でその関数(継続)を呼ぶ」関数、こそが、継続モナドにおけるモナドの「中身」になる。

    def makeCps(run: (Int => Unit) => Unit): (Int => Unit) => Unit = run

    def makeDupCps(x: Int): (Int => Unit) => Unit = {
      val dup = { resultHandler: (Int => Unit) => 
        val result = x * x
        resultHandler(result)
      }
      makeCps(dup)
    }

継続渡し形式の関数のクラス化

上のmakeCps()をクラス化するよ。名前はCpsクラスということにする。*1

まず、フィールドを用意して「run」を保持出来るようにする。あと引数の型と最終的な戻り値の型を変えられるように型パラメータにしておく。

/**
* CPSな関数を包む継続モナド。
* @param <T> このCPSの持つ中間的な値の型
* @param <R> 最終的な戻り値の型
*/
public class Cps<T, R> {
   private final Function<Function<T, R>, R> run;

   public Cps(Function<Function<T, R>, R> run) {
       this.run = run;
   }
/*

後続処理を渡して処理を実行出来るように、実行用のメソッドを用意する。

*/
   /**
    * @param k 後続処理
    * @return 後続処理kを指定してrunを実行した結果を戻す。
    */
   public R run(Function<T, R> k) {
       return run.apply(k);
   }
/*

モナドを作るためのunitを定義する。valueを渡すと「何もせず後続処理にvalueを渡す関数」を作って使用するようにする。

*/
   // 単にコンストラクタを呼ぶ。
   public static <T, R> Cps<T, R> cps(Function<Function<T, R>, R> run) {
       return new Cps<T, R>(run);
   }
   // valueから「何もせずvalueを渡す関数」を作って、それを含むCpsを作る。
   public static <T, R> Cps<T, R> unit(final T value) {
       Function<Function<T, R>, R> passthrough = makePassthrough(value);
       return cps(passthrough);
   }
   // 何もせずvalueを渡す関数を作る
   private static <T, R> Function<Function<T, R>, R> makePassthrough(final T value) {
       return new Function<Function<T, R>, R>() {
           public R apply(Function<T, R> k) {
               return k.apply(value);
           }
       };
   }
   // 関数として使う時用
   public static <T, R> Function<T, Cps<T, R>> unit() {
       return new Function<T, Cps<T,R>>() {
           public Cps<T, R> apply(T value) {
               return unit(value);
           }
       };
   }
   // あとでbindを定義する。
}

長いコードですが Scala版は以下のようにしました。

/**
 * CPSな関数を包む継続モナド。
 * @param <T> このCPSの持つ中間的な値の型
 * @param <R> 最終的な戻り値の型
 */
case class Cps[T, R](run: (T => R) => R) {
  // あとでbindを定義する。
}
object Cps {

  // valueから「何もせずvalueを渡す関数」を作って、それを含むCpsを作る。
  def unit[T, R](value: T): Cps[T, R] = {
    val passthrough: (T => R) => R = makePassthrough(value)
    Cps(passthrough)
  }

  // 何もせずvalueを渡す関数を作る
  private def makePassthrough[T, R](value: T): (T => R) => R = {
    k: (T => R) => k(value)
  }

}

runメソッドはそもそも cps.run(f) といった感じであたかもメソッドの様に runフィールドのapplyが呼べるので割愛。

cpsメソッドもcaseクラスにしたため、applyが自動で定義されるので省略しました。

Functionを返すunit()unit _で作れるのでこれも省略。

という訳で非常にシンプルになりました。

bindの定義は後でやるとして、一旦ここまでのクラスの実装で、「二乗して、それからプリントする」処理を書いてみる。

   // makeDupCpsのCpsクラス対応バージョン
   // xを二乗した値を後続処理に渡すCpsを戻す
   public Cps<Integer, Void> makeDupCps(final int x) {
       Function<Function<Integer, Void>, Void> dup =
           new Function<Function<Integer,Void>, Void>() {
               public Void apply(Function<Integer, Void> resultHandler) {
                   int result = x * x;
                   return resultHandler.apply(result);
               }
            };
       
       return new Cps<Integer, Void>(dup);
   }    
...
   // 使う側
   // xを二乗して、それからプリントする
   int x = 10;
   makeDupCps(x).run(printResult);

Functionのみで実装した例と比べると、makeDupCps()の戻り値がFunctionからCpsに変わったところと、後続処理を渡して実行するメソッド名がapply()からrun()に変わっているところだけが変更点で、それ以外はまったく同じであることが分かると思う。

    // makeDupCpsのCpsクラス対応バージョン
    // xを二乗した値を後続処理に渡すCpsを戻す
    def makeDupCps(x: Int): Cps[Int, Unit] = {
      val dup: (Int => Unit) => Unit = { resultHandler: (Int => Unit) =>
        val result = x * x
        resultHandler(result)
      }
      Cps(dup)
    }    
...
    // 使う側
    // xを二乗して、それからプリントする
    val x = 10
    makeDupCps(x).run(printResult)

x毎にxを二乗する関数を作成しないといけないのが残念な感じなので、「xを次に渡す処理」と「何かを受け取って二乗して次に渡す処理」に分割し、両者の合成として処理を書きたい。

モナドで処理の合成といえばbindなのでbindを実装する。

public class Cps<T, R> {
... //上のCpsの実装の続き
   /**
    * @param <S> 次の処理の持つ中間的な値の型
    * @param f TからCps<S, R>を作る関数
    * @return 自分自身のrun処理とfの処理を合成したCpsを戻す。
    */
   public <S> Cps<S, R> bind(final Function<T, Cps<S, R>> f) {
       return new Cps<S, R>(new Function<Function<S, R>, R>()  {
           public R apply(final Function<S, R> k) {
               return run(new Function<T, R>() {
                   public R apply(T value) {
                       return f.apply(value).run(k);
                   }
               });
           }
       });
   }
}

bindで戻される新たなCpsは、kを受け取り、「fを実行した後kを実行する」という後続処理を指定して元のCpsのrunを実行するものとなる。

結果として、run、f、kの形で処理が実行され、処理が合成される形になる。

class Cps[T, R] {
... //上のCpsの実装の続き

  /**
   * @param <S> 次の処理の持つ中間的な値の型
   * @param f TからCps<S, R>を作る関数
   * @return 自分自身のrun処理とfの処理を合成したCpsを戻す。
   */
  def bind[S](f: T => Cps[S, R]): Cps[S, R] = {
    Cps[S, R]((k: S => R) => run(value => f(value).run(k)))
  }
}

さっきの使用例を書き直してみる。

   // 使う側
   Function<Integer, Cps<Integer, Void>> dup = new Function<Integer, Cps<Integer, Void>>() {
       public Cps<Integer, Void> apply(Integer value) {
           return Cps.unit(value * value);
       }
   };
   Function<Integer, Void> printResult = new Function<Integer, Void>() {
       public Void apply(Integer value) {
           System.out.println("result = " + value);
           return null;
       }
   };

   int x = 10;
   // xを、二乗して、プリントする
   Cps.<Integer, Void>unit(x).bind(dup).run(printResult);

「なにもせずxを渡すCps」と「何かを受け取って二乗する処理」と「何かを受け取ってプリントする処理」のチェーンとして表現されている。

dupの型がさっきと変わっているので注意。

    // 使う側
    val dup: Int => Cps[Int, Unit] = value => Cps.unit(value * value)

    val printResult: Int => Unit = value => println("result = " + value)

    val x = 10
    // xを、二乗して、プリントする
    Cps.unit(x).bind(dup).run(printResult)

もう少し使用例

途中で処理の対象となる型が変わっても大丈夫という例を挙げてみる。

処理の全体像としては、「「10,20」のように数値二個がカンマで区切られた文字列を受け取り、それを足しあわせて、プリントする」というのを考える。

処理の途中で使う、整数のペアを保持出来るクラスを作っておく。(二組ペアなのでDoubleにしたいが紛らわしいのBinaryValueにする。)

public class BinaryValue {
   public final Integer left;
   public final Integer right;
   private BinaryValue(Integer left, Integer right) {
       this.left = left;
       this.right = right;
   }
   public static BinaryValue of(Integer left, Integer right) {
       return new BinaryValue(left, right);
   }
   @Override
   public int hashCode() {
       return HashCodeBuilder.reflectionHashCode(this);
   }
   @Override
   public boolean equals(Object obj) {
       return EqualsBuilder.reflectionEquals(this, obj);
   }
}

個々の処理を書いて行くよ。

まず、「「10,20」のように数値二個がカンマで区切られた文字列を受け取ってBinaryValueを次に渡す」処理を実装する。

public final class CpsTest {
...
   private static <R> Function<String, Cps<BinaryValue, R>>parse() {
       return new Function<String, Cps<BinaryValue,R>>() {
           public Cps<BinaryValue, R> apply(String value) {
               String[] components = value.split(",");
               Integer left = Integer.valueOf(components[0]);
               Integer right = Integer.valueOf(components[1]);

               return Cps.unit(new BinaryValue(left, right));
           }
       };
   }

最終的な戻り値の型を可変に出来るように、型パラメータ化&メソッド化してみた。エラー処理とかは省略している。

次に、「BinaryValueを受け取って中身を足しあわせて次に渡す」処理を実装。

   private static <R> Function<BinaryValue, Cps<Integer, R>>sum()  {
       return new Function<BinaryValue, Cps<Integer,R>>() {
           public Cps<Integer, R> apply(BinaryValue value) {
               return Cps.unit(value.left + value.right);
           }
       };
   }

結果をプリントする処理printResultの実装はさっきのと同じ。

ここは一塊で見たほうがわかり易いので一括で。

case class BinaryValue(left: Int, right: Int)

object CpsTest {

  def parse[R]: String => Cps[BinaryValue, R] = { value =>
     val Array(left, right) = value.split(",")
     Cps.unit(BinaryValue(left.toInt, right.toInt))
  }

  def sum[R]: BinaryValue => Cps[Int, R] = {
    case BinaryValue(left, right) => Cps.unit(left + right)
  }

}

19行あったBinaryValueが1行になりました。

では「「10,20」のように数値二個がカンマで区切られた文字列を受け取り、それを足しあわせて、プリントする」の処理を書いてみる。

   // 「10,20」から数値二個を読み取り、足しあわせ、プリントする
   Cps.<String, Void>unit("10,20")     // 「10,20」から
       .bind(CpsTest.<Void>parse())    // 数値二個を読み取り
       .bind(CpsTest.<Void>sum())      // 足しあわせ
       .run(printResult);              // プリント

実行してみる。

result = 30

足せている。

  // 「10,20」から数値二個を読み取り、足しあわせ、プリントする
  Cps.unit[String, Unit]("10,20")   // 「10,20」から
    .bind(CpsTest.parse[Unit])      // 数値二個を読み取り
    .bind(CpsTest.sum[Unit])        // 足しあわせ
    .run(printResult)               // プリント

同じ関数を再利用して「プリントする」代わりに「値を戻す」にしてみる。

「何もせず値を戻す」関数を定義

   private static <T> Function<T, T>id() {
       return new Function<T, T>() {
           public T apply(T value) {
               return value;
           }
       };
   }

Scalaには標準で identity があるので省略

「「10,20」のように数値二個がカンマで区切られた文字列を受け取り、それを足しあわせて、結果を戻す」に変更。

   @Test
   public void testCps() {
       Function<Integer, Integer> id = id();

       // 「10,20」から数値二個を読み取り、足しあわせ、戻す。
       Integer result =
           Cps.<String, Integer>unit("10,20")
               .bind(CpsTest.<Integer>parse())
               .bind(CpsTest.<Integer>sum())
               .run(id);

       assertEquals(30, result.intValue());
   }
  def restCps(): Unit = {
    val result: Int =
      Cps.unit[String, Int]("10,20")
        .bind(CpsTest.parse[Int])
        .bind(CpsTest.sum[Int])
        .run(identity)

    assert(30 == result)
  }

再帰処理の継続モナド化

「内部イテレータを外部イテレータに手動で変換してみる実験 - terazzoの日記」の時にやった階乗の計算処理を継続モナド化してみる。

まずは元のコード

   /** nの階乗数をresultHandlerにputResult()する*/
   public void factCps(final int n, final ResultHandler resultHandler) {
       if (n == 0) {
           resultHandler.putResult(1);
       } else {
           factCps(n - 1, new ResultHandler() {
               /* n - 1の階乗の結果がresultとして渡ってくるはず */
               public void putResult(int result) {
                   resultHandler.putResult(n * result);
               }
           });
       }
   }

   public void testFactCps() {
       factCps(10, new ResultHandler() {
           public void putResult(int result) {
               assertEquals("10の階乗は3628800", 3628800, result);
           }
       });
   }
  trait ResultHandler {
    def putResult(result: Int): Unit
  }

  /** nの階乗数をresultHandlerにputResult()する*/
  def factCps(n: Int, resultHandler: ResultHandler): Unit = {
    if (n == 0) {
      resultHandler.putResult(1)
    } else {
      factCps(n - 1, new ResultHandler() {
        /* n - 1の階乗の結果がresultとして渡ってくるはず */
        def putResult(result: Int): Unit = {
          resultHandler.putResult(n * result)
        }
      })
    }
  }

  def testFactCps(): Unit = {
    factCps(10, new ResultHandler() {
      def putResult(result: Int): Unit = {
        assert(result == 3628800, "10! == 3628800")
      }
    })
  }

「n==0の時1を、それ以外の時はfactCps(n - 1, resultHandler)のresultHandlerの部分に、最終結果をn倍する処理を渡す」という形。

これを素直に書き直すと、次のような感じか。

   private static <R> Function<Integer, Cps<Integer, R>> fact() {
       return  new Function<Integer, Cps<Integer,R>>() {
           public Cps<Integer, R> apply(final Integer n) {
               System.out.printf("fact(%d):\n", n);
               if (n == 0) {
                   return Cps.unit(1);
               } else {
                   return
                       CpsTest.<R>fact().apply(n - 1)
                           .bind(new Function<Integer, Cps<Integer,R>>() {
                               public Cps<Integer, R> apply(Integer value) {
                                   return Cps.unit(n * value);
                               }
                           });
               }
           }
       };
   }
   @Test
   public void testFact() {
       Function<Integer, Integer> id = id();
       Integer result =
           Cps.<Integer, Integer>unit(10)
               .bind(CpsTest.<Integer>fact())
               .run(id);
       assertEquals("10の階乗は3628800", 3628800, result.intValue());
   }

確かに最終結果を渡す部分は分離出来ているけど、中でfact()を呼んでいる部分がイマイチな気がする。

再帰的にfact()を呼ぶ部分自体も継続として扱えないだろうか。

  def fact[R]: Int => Cps[Int, R] = { n: Int =>
    println(s"fact(${n})")
    if (n == 0) Cps.unit(1)
    else fact(n - 1).bind(value => Cps.unit(n * value))
  }

  def testFact(): Unit = {
    val result = Cps.unit[Int, Int](10).bind(fact).run(identity)
    assert(result == 3628800, "10! == 3628800")
  }

後続処理の戻り値を使うということが出来ないので、アキュムレータを使って書き直す。Functionには引数が一つしか渡せないのでタプル的なクラスを作ってまとめる。

   private static class Arg {
       public final int value;	// 本来のパラメータ
       public final int acc;	// 結果を累積した値

       private Arg(int value, int acc) {
           this.value = value;
           this.acc = acc;
       }
       public static  Arg of(int value, int acc) {
           return new Arg(value, acc);
       }
   }
   private static <R> Function<Arg, Cps<Arg, R>> fact() {
       return  new Function<Arg, Cps<Arg,R>>() {
           public Cps<Arg, R> apply(final Arg arg) {
               if (arg.value == 0) {
                   return Cps.unit(arg);
               } else {
                   return Cps.unit(Arg.of(arg.value - 1, arg.acc * arg.value));
               }
           }
       };
   }
  case class Arg(value: Int, acc: Int)

  def fact[R]: (Arg) => Cps[Arg, R] = {
    case Arg(value, acc) =>
      if (value == 0) Cps.unit(Arg(value, acc))
      else Cps.unit(Arg(value - 1, acc * value))
  }

Scala はパターンマッチが使えるので値の分解も楽々ですね。

これを使う側は……再帰の回数分factをbindしてやれば動く。

   @Test
   public void testFact() {
       Function<Arg, Arg> id = id();
       Function<Arg, Cps<Arg, Arg>> fact = fact();
       int result =
           Cps.<Arg, Arg>unit(Arg.of(10, 1))
               .bind(fact)
               .bind(fact)
               .bind(fact)
               .bind(fact)
               .bind(fact)
               .bind(fact)
               .bind(fact)
               .bind(fact)
               .bind(fact)
               .bind(fact)
               .run(id).acc;

       assertEquals("10の階乗は3628800", 3628800, result);
   }
  def testFact(): Unit = {
    val Arg(value, acc) =
      Cps.unit[Arg, Arg](Arg(10, 1))
        .bind(fact)
        .bind(fact)
        .bind(fact)
        .bind(fact)
        .bind(fact)
        .bind(fact)
        .bind(fact)
        .bind(fact)
        .bind(fact)
        .bind(fact)
        .run(identity)
      assert(acc == 3628800, "10! == 3628800")
    }

回数分並べるのは実用的じゃない気がするので、終了したかどうかのフラグを持たせて、終了するまで繰り返しbind(fact)するようにしてみる。

   private static class Arg {
       public final int value;
       public final int acc;
       public final boolean isAtEnd;

       private Arg(int value, int acc, boolean isAtEnd) {
           this.value = value;
           this.acc = acc;
           this.isAtEnd = isAtEnd;
       }
       public static Arg of(int value, int acc) {
           return new Arg(value, acc, false);
       }
       public static Arg forResult(int acc) {
           return new Arg(0, acc, true);
       }
   }
   private static <R> Function<Arg, Cps<Arg, R>> fact() {
       return  new Function<Arg, Cps<Arg,R>>() {
           public Cps<Arg, R> apply(final Arg arg) {
               if (arg.value == 0) {
                   return Cps.unit(Arg.forResult(arg.acc)); // 終了の場合
               } else {
                   return Cps.unit(Arg.of(arg.value - 1, arg.acc * arg.value));
               }
           }
       };
   }
   @Test
   public void testFact() {
       Function<Arg, Arg> id = id();
       Function<Arg, Cps<Arg, Arg>> fact = fact();

       Cps<Arg, Arg> iter = Cps.<Arg, Arg>unit(Arg.of(10, 1));
       // 終了かどうかを確認し、それ以外ならさらにfactをbind
       while (!iter.run(id).isAtEnd) {
           iter = iter.bind(fact);
       }
       int result = iter.run(id).acc;

       assertEquals("10の階乗は3628800", 3628800, result);
   }
  case class Arg private (value: Int, acc: Int, isAtEnd: Boolean)
  object Arg {
    def of(value: Int, acc: Int): Arg = Arg(value, acc, false)
    def forResult(acc: Int): Arg = Arg(0, acc, true)
  }

  def fact[R]: Arg => Cps[Arg, R] = {
    case Arg(value, acc, isAtEnd) =>
      if (value == 0) Cps.unit(Arg.forResult(acc)) // 終了の場合
      else Cps.unit(Arg.of(value - 1, acc * value))
  }

  def testFact(): Unit = {
    val init: Cps[Arg, Arg] = Cps.unit(Arg.of(10, 1))
    def loop(iter: Cps[Arg, Arg]): Cps[Arg, Arg] = {
      if (iter.run(identity).isAtEnd) iter
      else loop(iter.bind(fact))
    }
    val result: Int = loop(init).run(identity).acc
    assert(result == 3628800, "10! == 3628800")
  }

while ループがアレげだったので末尾再起なloop内部関数を定義しました。

一応動く。が、問題点が二個ある。

一つは、終了フラグを取り出すのにrun(id)しているが、その度に計算全体が実行されること。

処理過程が分かるようにprint文を入れると次のようになる。

// System.out.printf("fact(%d,%d):\n", arg.value, arg.acc)を挿入して出した結果
fact(10,1):
fact(10,1):
fact(9,10):
fact(10,1):
fact(9,10):
fact(8,90):
fact(10,1):
fact(9,10):
fact(8,90):
fact(7,720)
...
fact(2,1814400):
fact(1,3628800):
fact(0,3628800):

Cpsを使ったチェーンは、実は最後にrunするまでは実行されない構造になっていたのだ。

もう一つは、この最後にrunするまでは実行されない構造により発生する問題で、値を大きくするとStack Overflowが発生する。*2

factをbindする際に、一旦結果を取り出して包み直すようにすれば大丈夫。

   @Test
   public void testFact() {
       Function<Arg, Arg> id = id();
       Function<Arg, Cps<Arg, Arg>> fact = fact();
       Cps<Arg, Arg> iter = Cps.<Arg, Arg>unit(Arg.of(10, 1));
       while (!iter.run(id).isAtEnd) {
           iter = Cps.unit(iter.run(id));
           iter = iter.bind(fact);
       }
       int result = iter.run(id).acc;
       assertEquals("10の階乗は3628800", 3628800, result);
   }

これで仮に10を100000にしてもStack Overflowがおこらなくなった。内容的にはトランポリンの継続モナド版とでもいうような動きになっている。

もう少し賢い方法があるような気がするんだけど思いつかない。

後で出て来るcall/ccで上手く書けないかと思ったけどやり方が分からない……。

  def testFact(): Unit = {
    val init: Cps[Arg, Arg] = Cps.unit(Arg.of(10, 1))
    def loop(iter: Cps[Arg, Arg]): Cps[Arg, Arg] = {
      if (iter.run(identity).isAtEnd) iter
      else loop(Cps.unit(iter.run(identity)).bind(fact))
    }
    val result: Int = loop(init).run(identity).acc
    assert(result == 3628800, "10! == 3628800")
  }

「ルールは大事よね」

等しいということは直接テストできないけど、とりあえず同じ入力に対する出力が一致することを確認する。

ちなみにモナド則はmonad lawsまたはmonad axiomsらしいです。

   // 拡張スタイルのモナド則
   // (return x) >>= f ≡ f x
   @Test
   public void testRule1() {
       Function<BinaryValue, BinaryValue> id = id();
       Function<String, Cps<String, BinaryValue>> unit = Cps.unit();
       Function<String, Cps<BinaryValue, BinaryValue>> f = parse();
       String x = "10,20";

       assertEquals(
           unit.apply(x).bind(f).run(id),
           f.apply(x).run(id)
       );
   }
   
   // m >>= return ≡ m
   @Test
   public void testRule2_simple() {
       Function<String, String> id = id();
       Function<String, Cps<String, String>> unit = Cps.unit();
       Cps<String, String> m = Cps.unit("abc");

       assertEquals(
           m.bind(unit).run(id),
           m.run(id)
       );
   }
   @Test
   public void testRule2_bound() {
       Function<BinaryValue, BinaryValue> id = id();
       Function<BinaryValue, Cps<BinaryValue, BinaryValue>> unit = Cps.unit();
       Cps<BinaryValue, BinaryValue> m =
           Cps.<String, BinaryValue>unit("10,20")
               .bind(CpsTest.<BinaryValue>parse());

       assertEquals(
           m.bind(unit).run(id),
           m.run(id)
       );
   }

   // (m >>= f) >>= g ≡ m >>= ( \x -> (f x >>= g) )
   @Test
   public void testRule3() {
       Function<Integer, Integer> id = id();
       final Function<String, Cps<BinaryValue, Integer>> f = parse();
       final Function<BinaryValue, Cps<Integer, Integer>> g = sum();
       
       Cps<String, Integer> m =  Cps.<String, Integer>unit("10,20");

       assertEquals(
               m.bind(f).bind(g).run(id),

               m.bind(new Function<String, Cps<Integer, Integer>>() {
                   public Cps<Integer, Integer> apply(String value) {
                       return f.apply(value).bind(g);
                   }
               }).run(id)
       );
   }
  // (return x) >>= f ≡ f x
  def testRule1(): Unit = {
    val f = CpsTest.parse[BinaryValue]
    val x = "10,20"

    val left = Cps.unit(x).bind(f).run(identity)
    val right = f(x).run(identity)
    assert(left == right)
  }

  // m >>= return ≡ m
  def testRule2_simple(): Unit = {
    val unit: String => Cps[String, String] = Cps.unit[String, String] _
    val m: Cps[String, String] = Cps.unit("abc")
    assert(m.bind(unit).run(identity) == m.run(identity))
  }

  def testRule2_bound(): Unit = {
    val unit: BinaryValue => Cps[BinaryValue, BinaryValue] = Cps.unit[BinaryValue, BinaryValue] _
    val m: Cps[BinaryValue, BinaryValue] = Cps.unit("10,20").bind(CpsTest.parse)
    assert(m.bind(unit).run(identity) == m.run(identity))
  }

  // (m >>= f) >>= g ≡ m >>= ( \x -> (f x >>= g) )
  def testRule3(): Unit = {
    val f: String => Cps[BinaryValue, Int] = CpsTest.parse[Int]
    val g: BinaryValue => Cps[Int, Int] = CpsTest.sum[Int]
    val m: Cps[String, Int] = Cps.unit("10,20")

    val left = m.bind(f).bind(g).run(identity)
    val right = m.bind(value => f(value).bind(g)).run(identity)
    assert(left == right)
  }

call/cc

継続モナドを使って、カレント継続のキャプチャ機能をエミュレート出来るらしい。

どこに書いても良いけど、Cpsクラスに以下のメソッドを追加。

public class Cps<T, R> {
...
   public static <S, T, R> Cps<T, R> callCC(final Function<Function<T, Cps<S, R>>, Cps<T, R>> f) {
       return new Cps<T, R>(new Function<Function<T, R>, R>() {
           public R apply(final Function<T, R> k) {
               return f.apply(new Function<T, Cps<S, R>>() {
                   public Cps<S, R> apply(final T value) {
                       return new Cps<S, R>(new Function<Function<S, R>, R>() {
                           public R apply(Function<S, R> x) {
                               return k.apply(value);
                           }
                       });
                   }
               }).run(k);
           }
       });
   }
}
object Cps {
...
  def callCC[S, T, R](f: (T => Cps[S, R]) => Cps[T, R]): Cps[T, R] = {
    Cps[T, R](k => f(value => Cps[S, R](x => k(value))).run(k))
  }
}

これを使って、処理のチェーンの途中で処理を中断して戻ることが出来る。

例として、先ほど実装した「「10,20」のように数値二個がカンマで区切られた文字列を受け取り、それを足しあわせて、プリントする」の数値のパースが失敗した時に処理を中断してプリントしないようにしてみる。

エラーチェック入りバージョンのパース用関数を実装。exitという脱出用関数を引数に取れるようにしておく。

   private static <R> Function<String, Cps<BinaryValue, R>> safetyParse(
           final Function<String, Cps<BinaryValue, R>> exit) {
       return new Function<String, Cps<BinaryValue,R>>() {
           public Cps<BinaryValue, R> apply(String value) {
               String[] components = value.split(",");
               if (components.length < 2) {
                   return exit.apply("*** Error: too few numbers.");
               }
               Integer left;
               Integer right;
               try {
                   left = Integer.valueOf(components[0]);
                   right = Integer.valueOf(components[1]);
               } catch (NumberFormatException e) {
                   return exit.apply("*** Error: invalid number format: " + e.getMessage());
               }

               return Cps.unit(new BinaryValue(left, right));
           }
       };
  def safetyParse[R](exit: String => Cps[BinaryValue, R]): String => Cps[BinaryValue, R] = { value =>
    val components = value.split(",")
    if (components.size < 2) {
      exit("*** Error: too few numbers.")
    } else {
      try {
        val Array(left, right) = components.map(_.toInt)
        Cps.unit(BinaryValue(left, right))
      } catch {
        case e: NumberFormatException => exit("*** Error: invalid number format: " + e.getMessage())
      }
    }
  }

printHandlerを文字列を取れるように修正して、結果をStringで受け取れるようにする。それにあわせて、Integerの結果をプリント出来るようにformat処理を定義しておく。

       Function<String, Void> printResult = new Function<String, Void>() {
           public Void apply(String value) {
               System.out.println(value);
               return null;
           }
       };
       final Function<Integer, Cps<String, Void>> format =
           new Function<Integer, Cps<String,Void>>() {
               public Cps<String, Void> apply(Integer value) {
                   return Cps.unit("result = " + value);
               }
           };
  val printResult: String => Unit = println _
  val format: Int => Cps[String, Unit] = value => Cps.unit("result = " + value)

処理全体をCps.callCCで囲んで、途中で脱出出来るようにする。

      Cps.callCC(
           new Function<Function<String, Cps<BinaryValue, Void>>, Cps<String, Void>>() {
               public Cps<String, Void> apply(
                       final Function<String, Cps<BinaryValue, Void>> exit) {
                   // 引数として、脱出用の関数が渡ってくる。

                   return Cps.<String, Void>unit("1020")   // 「1020」から ※カンマ忘れ
                       .bind(safetyParse(exit))            // 数値二個を安全に読み取り
                       .bind(CpsTest.<Void>sum())          // 足しあわせ
                       .bind(format);                      // プリント用に整形
                   }
               }
           ).run(printResult);                             // プリント

実行結果

*** Error: too few numbers.

sum()やformatがスキップされ、safetyParse()中でexitを呼び出した際の引数がプリントされた。

なんか思っていたcall/ccと少し違う気がするけど、大域脱出用に使えることは分かった。

これでコルーチンとかも書けるらしいけどまだちょっと使い方分からない。分かったらそのうち書くかも。

  Cps.callCC(
    (exit: String => Cps[BinaryValue, Unit]) => Cps.unit("1020")
      .bind(safetyParse(exit))
      .bind(CpsTest.sum[Unit])
      .bind(format)
  ).run(printResult)

(おまけ) Scalaなら for式でしょ

@terazzo さんの記事は上記まで。この記事では基本的に Haskell などと同じように unitbind でモナドを表現していました。

でも Scala では mapflatMap でモナドを表現すると for式が使えてちょっとお得な感じになります。

bind == flatMap なので単純にリネームして、後は map を定義してみましょう。

case class Cps[T, R](run: (T => R) => R) {

  def flatMap[S](f: T => Cps[S, R]): Cps[S, R] = {
    Cps[S, R]({k: (S => R) => run(value => f(value).run(k))})
  }

  def map[S](f: T => S): Cps[S, R] = {
    Cps.unit(f) flatMap {f => this flatMap {v => Cps.unit(f(v))}}
  }

}
object Cps {

  def unit[T, R](value: T): Cps[T, R] = {
    val passthrough: (T => R) => R = makePassthrough(value)
    Cps(passthrough)
  }

  private def makePassthrough[T, R](value: T): (T => R) => R = {
    k: (T => R) => k(value)
  }

  def callCC[S, T, R](f: (T => Cps[S, R]) => Cps[T, R]): Cps[T, R] = {
    Cps[T, R](k => f(value => Cps[S, R](x => k(value))).run(k))
  }

}

そうすると以下のような感じで for式が使えます。やったね!

  Cps.callCC(
    (exit: String => Cps[BinaryValue, Unit]) => for {
      a <- safetyParse[Unit](exit)("1020")
      b <- CpsTest.sum.apply(a)
      c <- format(b)
    } yield c
  ).run(printResult)
@xuwei-k
Copy link

xuwei-k commented Dec 7, 2013

case class Arg(value: Int, value: acc)

case class Arg(value: Int, acc: Int)

の間違い?

@gakuzzzz
Copy link
Author

gakuzzzz commented Dec 8, 2013

お!その通りです。ありがとうございます。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment