@polytypic
Created April 29, 2021 07:22
Learning day: ReDiSL

The Logistics team held a "learning day" on Thursday 2021-03-04 where we could spend one full day on a project or tutorial of our choosing. I worked on a little proof-of-concept project I dubbed ReDiSL, or Redis DSL. This is a short writeup on the topic.

The motivation for ReDiSL is that it could make for a nicer way to work with Redis, especially for more complex operations that are best done by running Lua scripts on Redis.

The old way of working with Lua scripts is to just write them down as strings

private[lib] val Renewed = -9
private[lib] val Claimed = -8

/** This should return `Renewed` or `Claimed` when we win and the other winner's TTL >= 0 when we lose. */
private val RenewScript = {
  val raceKey    = "KEYS[1]"
  val runnerUid  = "ARGV[1]"
  val newTTL     = "ARGV[2]"
  val ttlChannel = "ARGV[3]"
  s"""|local owner = redis.call("GET", $raceKey)
      |if owner == $runnerUid then
      |  redis.call("PEXPIRE", $raceKey, $newTTL)
      |  redis.call("PUBLISH", $ttlChannel, $newTTL)
      |  return $Renewed
      |elseif owner then
      |  return redis.call("PTTL", $raceKey)
      |else
      |  redis.call("PSETEX", $raceKey, $newTTL, $runnerUid)
      |  redis.call("PUBLISH", $ttlChannel, $newTTL)
      |  return $Claimed
      |end""".stripMargin
}

and then one needs to load the script to Redis

private var loadedSHA: Option[String] = None
private def renewSHA: String =
  loadedSHA.getOrElse {
    loadedSHA = Option(jedisPool.withJedisClient(_.scriptLoad(RenewScript)))
    renewSHA
  }

and probably one also wants to have some wrapper for calling the script

private[lib] def renewInternal(
    raceId: String,
    runnerUid: String,
    newTTL: Long,
    ttlChannel: String
  ): Long =
  jedisPool
    .withJedisClient(_.evalsha(renewSHA, List(s"lib.LeaderElection:$raceId").asJava, List(runnerUid, newTTL.toString, ttlChannel).asJava))
    .asInstanceOf[Long]

so doing all of that manually isn't really ideal.

The idea is that, ultimately, with a suitable Scala->Redis macro, one could just define the wrapper function roughly like

private val renewInternal = ReDiSL.command((
    raceKey: Exp[Key[String]],
    ttlChannel: Exp[Channel[IntegerType]],
    runnerUid: Exp[String],
    newTTL: Exp[IntegerType]
  ) => lua {
    val owner = GET(raceKey)
    if (owner === runnerUid) {
      PEXPIRE(raceKey, newTTL)
      PUBLISH(ttlChannel, newTTL)
      Renewed
    }
    else if (Nil ~== owner) {
      PTTL(raceKey)
    }
    else {
      PSETEX(raceKey, newTTL, runnerUid)
      PUBLISH(ttlChannel, newTTL)
      Claimed
    }
  }
)

and that would give you a function that you could call with a Redis handle and the other necessary parameters. The advantages here would be that the Scala code snippet (and thereby the generated Lua) would be type checked and that parameters would be guaranteed to be passed correctly to the script. It would also likely be possible to do some optimizations on the generated code and, for really simple commands, to avoid generating Lua altogether and just execute one or a few Redis commands directly.

For the learning day I set my goals low enough that I could conceivably get something ready by the end of the day. I started by drafting an AST for the language based on HOAS (higher-order abstract syntax) and GADTs (generalized algebraic data types):

sealed trait Exp[+T] {
  def === [U >: T](that: Exp[U]): Exp[Boolean] = Equals(this, that)
  def ~== [U >: T](that: Exp[U]): Exp[Boolean] = EqualsNot(this, that)
}

final case class GET[T](key: Exp[Key[T]]) extends Exp[T]
final case class PEXPIRE[T](key: Exp[Key[T]], ttlMillis: Exp[IntegerType]) extends Exp[IntegerType]
final case class PTTL[T](key: Exp[Key[T]]) extends Exp[IntegerType]
final case class PUBLISH[T](channel: Exp[Channel[T]], message: Exp[T]) extends Exp[IntegerType]
final case class PSETEX[T](key: Exp[Key[T]], ttlMillis: Exp[IntegerType], value: Exp[T]) extends Exp[String]

final case class IfElse[T](condition: Exp[Boolean], consequent: Exp[T], alternative: Exp[T]) extends Exp[T]

final case class Let[T, U](body: Exp[T] => Exp[U], value: Exp[T]) extends Exp[U]
final case class Var[T](name: String) extends Exp[T]

final case class Equals[T](lhs: Exp[T], rhs: Exp[T]) extends Exp[Boolean]
final case class EqualsNot[T](lhs: Exp[T], rhs: Exp[T]) extends Exp[Boolean]

final case object Nil extends Exp[Any]
final case class Bool(value: Boolean) extends Exp[Boolean]
final case class Integer(value: scala.Long) extends Exp[IntegerType]
final case class Floating(value: scala.Double) extends Exp[FloatingType]
final case class Str(value: String) extends Exp[String]

final case class AndThen[T, U](first: Exp[T], second: Exp[U]) extends Exp[U]
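The listing above relies on a few marker types (`Key`, `Channel`, `IntegerType`) whose definitions are not shown. The following is a minimal sketch, assuming they are plain phantom types and that `IntegerType` is an alias for `Long`; both are assumptions made for illustration, not the actual definitions. It shows how the type parameters tie a key or channel to the type of the value that flows through it.

```scala
object PhantomTypesSketch {
  // Assumed definitions: these marker types are not shown in the writeup.
  sealed trait Key[T]      // a Redis key whose stored value has type T
  sealed trait Channel[T]  // a pub/sub channel carrying messages of type T
  type IntegerType = Long  // assumption: Redis integers modelled as Long

  // A cut-down copy of the AST, just enough for the example.
  sealed trait Exp[+T]
  final case class Var[T](name: String) extends Exp[T]
  final case class GET[T](key: Exp[Key[T]]) extends Exp[T]
  final case class PUBLISH[T](channel: Exp[Channel[T]], message: Exp[T]) extends Exp[IntegerType]

  val raceKey    = Var[Key[String]]("KEYS[1]")
  val ttlChannel = Var[Channel[IntegerType]]("ARGV[1]")
  val newTTL     = Var[IntegerType]("ARGV[2]")

  // GET on a Key[String] is an Exp[String].
  val owner: Exp[String] = GET(raceKey)

  // Publishing an integer on an integer channel type checks.
  val published: Exp[IntegerType] = PUBLISH(ttlChannel, newTTL)

  // Publishing the String-typed `owner` on the integer channel should be
  // rejected by the compiler, because Channel is invariant in T:
  // val bad = PUBLISH(ttlChannel, owner)
}
```

Keeping `Key` and `Channel` invariant is what makes the mismatch a compile error here; `Exp` itself can stay covariant.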

One technicality in getting to the goal outlined above is to map from the multi-parameter lambda

private val renewInternal = ReDiSL.command((
    raceKey: Exp[Key[String]],
    ttlChannel: Exp[Channel[IntegerType]],
    runnerUid: Exp[String],
    newTTL: Exp[IntegerType]
  ) => lua {

to something that can be manipulated more easily, for which I drafted a bunch of type class machinery:

trait ArgTuple[Exps, Vals] {
  def exps: Exps
}

trait Arg[Val] {
  def exp(keys: Int, args: Int): (Exp[Val], (Int, Int))
  def exps(state: (Int, Int)) = (exp _).tupled(state)
}

implicit def ArgKey[T]: Arg[Key[T]] = new Arg[Key[T]] {
  def exp(keys: Int, args: Int) = (Var[Key[T]](s"KEYS[$keys]"), (keys + 1, args))
}

// Note: Redis exposes the non-key script arguments to Lua as the ARGV table.
implicit def ArgChannel[T]: Arg[Channel[T]] = new Arg[Channel[T]] {
  def exp(keys: Int, args: Int) = (Var[Channel[T]](s"ARGV[$args]"), (keys, args + 1))
}

implicit val ArgString: Arg[String] = new Arg[String] {
  def exp(keys: Int, args: Int) = (Var[String](s"ARGV[$args]"), (keys, args + 1))
}

implicit val ArgLong: Arg[IntegerType] = new Arg[IntegerType] {
  def exp(keys: Int, args: Int) = (Var[IntegerType](s"ARGV[$args]"), (keys, args + 1))
}

implicit val ArgDouble: Arg[FloatingType] = new Arg[FloatingType] {
  def exp(keys: Int, args: Int) = (Var[FloatingType](s"ARGV[$args]"), (keys, args + 1))
}

implicit def ArgTuple4[T1, T2, T3, T4](
    implicit
    a1: Arg[T1],
    a2: Arg[T2],
    a3: Arg[T3],
    a4: Arg[T4]
  ): ArgTuple[(Exp[T1], Exp[T2], Exp[T3], Exp[T4]), (T1, T2, T3, T4)] =
  new ArgTuple[(Exp[T1], Exp[T2], Exp[T3], Exp[T4]), (T1, T2, T3, T4)] {

    def exps = {
      val (e1, s2) = a1.exp(1, 1)
      val (e2, s3) = a2.exps(s2)
      val (e3, s4) = a3.exps(s3)
      val (e4, _)  = a4.exps(s4)
      (e1, e2, e3, e4)
    }
  }

The drafted machinery only works for 4 arguments given as a tuple, but it should be easy to extend from there.
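As a rough illustration of how another arity might be added, here is a self-contained sketch of a two-argument instance built the same way. The names `argTuple2`, `argKey`, `argString` and `example` are hypothetical, the `Key` marker type is an assumed definition, and non-key arguments are named `ARGV`, which is the table Redis exposes to Lua scripts.

```scala
object ArgTupleSketch {
  sealed trait Key[T] // assumed phantom marker type
  sealed trait Exp[+T]
  final case class Var[T](name: String) extends Exp[T]

  trait Arg[Val] {
    // Takes the next free KEYS and ARGV indices and returns the variable
    // together with the updated pair of indices.
    def exp(keys: Int, args: Int): (Exp[Val], (Int, Int))
    def exps(state: (Int, Int)): (Exp[Val], (Int, Int)) = exp(state._1, state._2)
  }

  trait ArgTuple[Exps, Vals] {
    def exps: Exps
  }

  implicit def argKey[T]: Arg[Key[T]] = new Arg[Key[T]] {
    def exp(keys: Int, args: Int): (Exp[Key[T]], (Int, Int)) =
      (Var[Key[T]](s"KEYS[$keys]"), (keys + 1, args))
  }

  implicit val argString: Arg[String] = new Arg[String] {
    def exp(keys: Int, args: Int): (Exp[String], (Int, Int)) =
      (Var[String](s"ARGV[$args]"), (keys, args + 1))
  }

  // Same pattern as the 4-tuple: thread the (keys, args) counters left to right.
  implicit def argTuple2[T1, T2](
      implicit
      a1: Arg[T1],
      a2: Arg[T2]
    ): ArgTuple[(Exp[T1], Exp[T2]), (T1, T2)] =
    new ArgTuple[(Exp[T1], Exp[T2]), (T1, T2)] {
      def exps: (Exp[T1], Exp[T2]) = {
        val (e1, s2) = a1.exp(1, 1) // Lua tables are 1-indexed
        val (e2, _)  = a2.exps(s2)
        (e1, e2)
      }
    }

  // A key followed by a string argument maps to KEYS[1] and ARGV[1].
  val example: (Exp[Key[String]], Exp[String]) = argTuple2[Key[String], String].exps
}
```

Other arities would follow the same shape; a generic encoding (for example over nested pairs) could avoid writing each one by hand, but that is beyond this sketch.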

After getting the arguments sorted out, I wrote a simple compiler from the AST to Lua:

import java.util.concurrent.atomic.AtomicInteger

private def wrapExprAsStmt(isTail: Boolean, indent: Int, exp: String) =
  " " * indent + (if (isTail) "return " else "") + exp + "\n"

private def toLuaStmt[T](
    isTail: Boolean,
    indent: Int,
    nextId: AtomicInteger,
    exp: Exp[T]
  ): String =
  exp match {
    case (_: GET[_] | _: PEXPIRE[_] | _: PTTL[_] | _: PUBLISH[_] | _: PSETEX[_] | Nil | _: Bool | _: Integer | _: Floating | _: Str | _: Equals[_] |
        _: EqualsNot[_] | _: Var[_]) =>
      wrapExprAsStmt(isTail, indent, toLuaExpr(nextId, exp))

    case e: Let[u, t] => {
      val variable = Var[u]("id_" + nextId.getAndIncrement())

      " " * indent + "local " + toLuaExpr(nextId, variable) + " = " + toLuaExpr(nextId, e.value) + "\n" +
        toLuaStmt(isTail, indent, nextId, e.body(variable))
    }

    case e: IfElse[t] =>
      " " * indent + "if " + toLuaExpr(nextId, e.condition) + " then\n" +
        toLuaStmt(isTail, indent + 2, nextId, e.consequent) +
        " " * indent + "else\n" +
        toLuaStmt(isTail, indent + 2, nextId, e.alternative) +
        " " * indent + "end\n"

    case e: AndThen[t, u] => toLuaStmt(false, indent, nextId, e.first) + toLuaStmt(isTail, indent, nextId, e.second)
  }

private def toLuaExpr[T](nextId: AtomicInteger, exp: Exp[T]): String =
  exp match {
    case e: GET[_]     => s"""redis.call("GET", ${toLuaExpr(nextId, e.key)})"""
    case e: PEXPIRE[_] => s"""redis.call("PEXPIRE", ${toLuaExpr(nextId, e.key)}, ${toLuaExpr(nextId, e.ttlMillis)})"""
    case e: PTTL[_]    => s"""redis.call("PTTL", ${toLuaExpr(nextId, e.key)})"""
    case e: PUBLISH[_] => s"""redis.call("PUBLISH", ${toLuaExpr(nextId, e.channel)}, ${toLuaExpr(nextId, e.message)})"""
    case e: PSETEX[_]  => s"""redis.call("PSETEX", ${toLuaExpr(nextId, e.key)}, ${toLuaExpr(nextId, e.ttlMillis)}, ${toLuaExpr(nextId, e.value)})"""

    case e: Equals[t]    => s"""(${toLuaExpr(nextId, e.lhs)} == ${toLuaExpr(nextId, e.rhs)})"""
    case e: EqualsNot[t] => s"""(${toLuaExpr(nextId, e.lhs)} ~= ${toLuaExpr(nextId, e.rhs)})"""

    case Nil         => "nil"
    case e: Bool     => e.value.toString
    case e: Integer  => e.value.toString
    case e: Floating => e.value.toString
    case e: Str      => "[====[" + e.value + "]====]" // TODO: Format string literals properly

    case e: Var[_] => e.name

    case (_: Let[_, _] | _: IfElse[_] | _: AndThen[_, _]) =>
      "(function ()\n" + toLuaStmt(true, 2, nextId, exp) + "end)()"
  }

def toLua[Exps, Vals, Res](script: Exps => Exp[Res])(implicit args: ArgTuple[Exps, Vals]): String =
  toLuaStmt(true, 0, new AtomicInteger(1), script(args.exps))
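The `Str` case above is marked with a TODO: the long-bracket form breaks if the value itself contains `]====]`. One possible way to address it, sketched here under the assumption that string values should be sent as UTF-8 bytes, is to emit a double-quoted Lua literal and escape everything outside printable ASCII with the `\ddd` decimal escapes that Lua 5.1 supports. The helper name `luaStringLiteral` is hypothetical.

```scala
import java.nio.charset.StandardCharsets

object LuaStringSketch {
  /** Renders `value` as a double-quoted Lua 5.1 string literal (hypothetical helper). */
  def luaStringLiteral(value: String): String = {
    val sb = new StringBuilder
    sb.append('"')
    // Redis strings are byte strings, so escape the UTF-8 bytes
    // rather than the UTF-16 chars of the Scala string.
    value.getBytes(StandardCharsets.UTF_8).foreach { b =>
      val code = b & 0xff
      if (code == 0x22) sb.append("\\\"")      // double quote
      else if (code == 0x5c) sb.append("\\\\") // backslash
      else if (code >= 0x20 && code < 0x7f) sb.append(code.toChar)
      // Always three digits, so a following digit cannot be misread
      // as part of the escape.
      else sb.append('\\').append("%03d".format(code))
    }
    sb.append('"')
    sb.toString
  }
}
```

With something like this, the `Str` case could call `luaStringLiteral(e.value)` instead of concatenating long brackets.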

The compiler is very simple. The one "optimization" it does is that it switches between generating statements and expressions. The AST itself essentially only has expressions. However, some of the expressions, like the HOAS Let, are better compiled to Lua statements than expressions. So, the compiler eagerly switches between generating statements and expressions and when a statement is required in an expression context, a Lua function is generated inside of which it is possible to generate statements. In the tested use case this is sufficient to generate exactly the desired code (with no unnecessary Lua functions):

proc mustBe """|local id_1 = redis.call("GET", KEYS[1])
               |if (id_1 == ARGV[2]) then
               |  redis.call("PEXPIRE", KEYS[1], ARGV[3])
               |  redis.call("PUBLISH", ARGV[1], ARGV[3])
               |  return -9
               |else
               |  if (nil ~= id_1) then
               |    return redis.call("PTTL", KEYS[1])
               |  else
               |    redis.call("PSETEX", KEYS[1], ARGV[3], ARGV[2])
               |    redis.call("PUBLISH", ARGV[1], ARGV[3])
               |    return -8
               |  end
               |end
               |""".stripMargin
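The tested script never hits the fallback where a statement-like node lands in expression position. To illustrate that path, here is a cut-down sketch of the same compiler structure with only four node types; it is a reduced copy for illustration, not the full compiler, and `example` is a made-up input in which a `Let` appears as an operand of `Equals`.

```scala
import java.util.concurrent.atomic.AtomicInteger

object StmtExprSketch {
  sealed trait Exp[+T]
  final case class Var[T](name: String) extends Exp[T]
  final case class Integer(value: Long) extends Exp[Long]
  final case class Equals[T](lhs: Exp[T], rhs: Exp[T]) extends Exp[Boolean]
  final case class Let[T, U](body: Exp[T] => Exp[U], value: Exp[T]) extends Exp[U]

  def toLuaStmt[T](isTail: Boolean, indent: Int, nextId: AtomicInteger, exp: Exp[T]): String =
    exp match {
      // Let compiles naturally to a `local` statement followed by its body.
      case e: Let[u, t] =>
        val variable = Var[u]("id_" + nextId.getAndIncrement())
        " " * indent + "local " + variable.name + " = " + toLuaExpr(nextId, e.value) + "\n" +
          toLuaStmt(isTail, indent, nextId, e.body(variable))
      // Everything else is an expression wrapped as a statement.
      case _ =>
        " " * indent + (if (isTail) "return " else "") + toLuaExpr(nextId, exp) + "\n"
    }

  def toLuaExpr[T](nextId: AtomicInteger, exp: Exp[T]): String =
    exp match {
      case e: Var[_]    => e.name
      case e: Integer   => e.value.toString
      case e: Equals[_] => "(" + toLuaExpr(nextId, e.lhs) + " == " + toLuaExpr(nextId, e.rhs) + ")"
      // A statement in expression position: wrap it in an immediately
      // invoked Lua function so that statements can be generated inside.
      case _: Let[_, _] => "(function ()\n" + toLuaStmt(true, 2, nextId, exp) + "end)()"
    }

  // A Let used as an operand forces the function wrapper.
  val example: String =
    toLuaStmt(true, 0, new AtomicInteger(1), Equals(Let[Long, Long](x => x, Integer(1)), Integer(1)))
}
```

For this input the sketch produces a `return` whose left operand is a small anonymous function containing `local id_1 = 1` and `return id_1`, which is exactly the extra function the tested use case manages to avoid.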

Writing the ReDiSL AST by hand

let(GET(raceKey)) { owner =>
  IfElse(
    owner === runnerUid,
    AndThen(PEXPIRE(raceKey, newTTL), AndThen(PUBLISH(ttlChannel, newTTL), Renewed)),
    IfElse(
      Nil ~== owner,
      PTTL(raceKey),
      AndThen(
        PSETEX(raceKey, newTTL, runnerUid),
        AndThen(
          PUBLISH(ttlChannel, newTTL),
          Claimed
        )
      ))
  )
}

is not ideal. It would be possible to improve on that by using some custom operators, for example, but I wanted to see if I could write a lua macro that generates the AST from ordinary looking Scala code using Scala's own blocks and if-expressions

def action(
    raceKey: Exp[Key[String]],
    ttlChannel: Exp[Channel[IntegerType]],
    runnerUid: Exp[String],
    newTTL: Exp[IntegerType]
  ) = lua {
  val owner = GET(raceKey)
  if (owner === runnerUid) {
    PEXPIRE(raceKey, newTTL)
    PUBLISH(ttlChannel, newTTL)
    Renewed
  }
  else if (Nil ~== owner) {
    PTTL(raceKey)
  }
  else {
    PSETEX(raceKey, newTTL, runnerUid)
    PUBLISH(ttlChannel, newTTL)
    Claimed
  }
}

I didn't manage to get that working quite in time for the demo, but I got it working on the same day. The key problem I ran into is that the Scala expression tree given to Scala macros contains some type information, which needs to be erased (by calling untypecheck) when doing the kind of rewrites that I was doing in the lua macro implementation:

object Syntax {
  import scala.annotation.{nowarn, unused}
  import scala.language.experimental.macros
  import scala.reflect.macros.blackbox.Context
  import scala.language.implicitConversions

  implicit def ifKludge(@unused e: Exp[Boolean]) = true

  def lua[T](expr: Exp[T]): Exp[T] = macro luaMacro

  def luaMacro(c: Context)(expr: c.Tree): c.Tree = {
    import c.universe._

    @nowarn
    def rewrite(expr: c.Tree): c.Tree =
      expr match {
        case q"if (ReDiSL.Syntax.ifKludge($condition)) { ..$consequent } else { ..$alternative }" =>
          q"ReDiSL.IfElse.apply(${rewrite(condition)}, ${rewrite(q"{ ..$consequent }")}, ${rewrite(q"{ ..$alternative }")})"
        case q"{ ..$stmts }" if stmts.length > 1 => {
          stmts match {
            case q"val $name = $expr" :: stmts =>
              val body = rewrite(q"{ ..${stmts} }")
              q"ReDiSL.let($expr)(xxx => {val $name = xxx; $body})"
            case stmt :: stmts =>
              q"ReDiSL.AndThen.apply(${rewrite(stmt)}, ${rewrite(q"{ ..$stmts }")})"
            case scala.Nil =>
              throw new Exception("impossible")
          }
        }
        case _ => expr
      }

    rewrite(c.untypecheck(expr))
  }
}
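The macro emits calls to `ReDiSL.let`, and both the hand-written AST and the `lua` block use `Renewed` and `Claimed` as expressions; none of these definitions are shown. A minimal sketch of what they might look like follows, assuming `let` is a thin curried wrapper around the `Let` node and the two constants are `Integer` literals. All three definitions here are guesses for illustration.

```scala
object LetHelperSketch {
  sealed trait Exp[+T]
  final case class Var[T](name: String) extends Exp[T]
  final case class Integer(value: Long) extends Exp[Long]
  final case class Let[T, U](body: Exp[T] => Exp[U], value: Exp[T]) extends Exp[U]

  // Hypothetical helper: taking the value in a first parameter list lets
  // Scala infer T before it type checks the lambda in the second one.
  def let[T, U](value: Exp[T])(body: Exp[T] => Exp[U]): Exp[U] = Let(body, value)

  // Assumed expression-level counterparts of the Renewed/Claimed constants.
  val Renewed: Exp[Long] = Integer(-9)
  val Claimed: Exp[Long] = Integer(-8)

  // Usage: the bound variable is available to the body as an ordinary Scala value.
  val example: Exp[Long] = let(Var[String]("KEYS[1]")) { _ => Renewed }
}
```

The argument order is flipped relative to the `Let` case class purely for type inference and for the block-style call syntax.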

You may have noticed the ifKludge

implicit def ifKludge(@unused e: Exp[Boolean]) = true

This was needed because the ReDiSL expressions have types of the form Exp[_], while an ordinary Scala conditional, for example, requires the type Boolean rather than Exp[Boolean]. When using a Scala macro, the code passed to the macro is type checked as such before the macro rewrites the code. So, having learned this, I think that for something like ReDiSL one would want to create a combinator library with unwrapped types (i.e. T instead of Exp[T]). Those would work as a "skin" that the Scala compiler gets to type check before the macro rewrite. The macro rewrite could then generate the actual code that uses wrapped types (i.e. Exp[T] instead of T) inside.
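To make the "skin" idea a little more concrete, here is a rough sketch of what such a surface API might look like. Everything in it is hypothetical: the functions have plain types and placeholder bodies, and exist only so that the Scala compiler can type check the user's code before a macro (not shown) rewrites it to the `Exp[T]` based AST.

```scala
object SkinSketch {
  sealed trait Key[T] // assumed phantom marker type

  // Placeholder body: the surface functions are never meant to run; a macro
  // would replace each call with the corresponding AST node.
  private def macroOnly: Nothing =
    throw new UnsupportedOperationException("only meaningful inside lua { ... }")

  // Unwrapped signatures: T instead of Exp[T].
  def GET[T](key: Key[T]): T = macroOnly
  def PTTL[T](key: Key[T]): Long = macroOnly

  val Renewed: Long = -9L

  // Type checks with a plain Boolean condition, so no implicit
  // Exp[Boolean] => Boolean conversion is needed.
  def surface(raceKey: Key[String], runnerUid: String): Long =
    if (GET(raceKey) == runnerUid) Renewed else PTTL(raceKey)
}
```

In a real library the surface functions could also be annotated with `scala.annotation.compileTimeOnly`, so that using them outside the macro becomes a compile error rather than a runtime exception.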

So, could and should something like this be built? I think it is definitely doable, but my proof-of-concept implementation is just that. It would need a lot more code to cover more of the Redis API and some more plumbing to make it as smooth as possible to use (in the way outlined earlier). On the other hand, if we wanted something like this, I would recommend building it as an open source library; maintaining it just for ourselves would likely be a waste. Would it help? Possibly, yes. Given a robust enough implementation, this kind of ReDiSL could make it considerably nicer to interface with Redis. It would likely make it easier to avoid network round trips and allow us to optimize Redis usage. It is, of course, possible to do all of that manually.
