Skip to content

Instantly share code, notes, and snippets.

@chadselph
Last active January 6, 2017 06:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chadselph/f0e1559ecc2b178b83ab02dc50fc41ca to your computer and use it in GitHub Desktop.
Save chadselph/f0e1559ecc2b178b83ab02dc50fc41ca to your computer and use it in GitHub Desktop.
tiny slick-based distributed task queue prototype
import java.sql.Timestamp
import java.time.{Duration, Instant}
import slick.driver.JdbcProfile
import slick.profile.SqlProfile.ColumnOption.Nullable
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
/**
  * Created by chad on 1/5/17.
  *
  * A small task-run ledger stored in a single Slick table. Each row records one run of a
  * named task (who started it, when it started and, once done, when it finished). Workers
  * call [[startTaskIf]] to decide, based on the most recent run, whether to start a new one.
  *
  * @param tableName name of the table holding the task runs
  * @param driver    the Slick JDBC profile used to build queries
  * @param database  the database the queries are run against
  */
class TaskQueue[Driver <: JdbcProfile](tableName: String = "tasks",
                                       val driver: Driver,
                                       database: Driver#Backend#Database) {

  import driver.api._

  type TaskName = String

  /** Policy deciding, from the most recent run (if any), whether a new run should start. */
  type ShouldRun = Option[TaskRun] => Boolean

  /**
    * Use this if you want your tasks to start every `duration` without caring whether or
    * not the previous one has finished. Keep in mind, if your task takes longer than `duration`
    * you run the risk of runs piling on top of each other.
    *
    * @param duration how long to wait after the previous run started before trying again
    */
  case class HasNotStartedInLast(duration: Duration) extends ShouldRun {
    // No previous run at all means we should run.
    override def apply(lastRun: Option[TaskRun]): Boolean =
      lastRun.forall(run => wasMoreThanDurationAgo(duration, run.startedAt))
  }

  /**
    * You can use this class to keep your tasks from stacking on top of each other, i.e. if the
    * task takes longer than the wait time.
    * A new run only starts when the previous run has finished or we think it is dead.
    *
    * @param duration   how long to wait after the previous run finished before running again.
    * @param expiration how long to wait after the startedAt time of an unfinished run before we
    *                   decide to run again anyway.
    */
  case class HasNotFinishedInLast(duration: Duration, expiration: Duration)
      extends ShouldRun {
    override def apply(lastRun: Option[TaskRun]): Boolean = lastRun match {
      case None => true
      // Still unfinished: only take over once it has been running longer than `expiration`.
      case Some(TaskRun(_, _, _, started, None)) =>
        wasMoreThanDurationAgo(expiration, started)
      case Some(TaskRun(_, _, _, _, Some(finished))) =>
        wasMoreThanDurationAgo(duration, finished)
    }
  }

  /** Outcome of a [[startTaskIf]] call. */
  sealed trait RunResult

  /** A new run row was inserted and the action was invoked. */
  case class StartRun(previousRun: Option[TaskRun], thisRun: TaskRun)
      extends RunResult

  /** The policy declined to start a new run. */
  case class SkipRun(previousRun: Option[TaskRun]) extends RunResult

  /** Inserting the new run row failed; the action was not invoked. */
  case class InsertFailed(cause: Throwable) extends RunResult

  /** One row of the task table: a single run of the task called `name`. */
  case class TaskRun(name: String,
                     runId: Int,
                     worker: String,
                     startedAt: Instant = Instant.now,
                     finishedAt: Option[Instant] = None)

  /** Table mapping for [[TaskRun]]; `(name, run_id)` is the primary key. */
  class TaskRuns(tag: Tag) extends Table[TaskRun](tag, tableName) {
    // Instants are stored as SQL timestamps.
    implicit val instantColumnType: BaseColumnType[Instant] =
      MappedColumnType.base[Instant, Timestamp](Timestamp.from, _.toInstant)

    def name = column[TaskName]("name")
    def runId = column[Int]("run_id")
    def worker = column[String]("worker")
    def startedAt = column[Instant]("started_at")
    def finishedAt = column[Instant]("finished_at", Nullable)

    def * =
      (name, runId, worker, startedAt, finishedAt.?) <> (TaskRun.tupled, TaskRun.unapply)

    def pk = primaryKey("pk_name_run_id", (name, runId))
  }

  val tasks = TableQuery[TaskRuns]

  /**
    * Builds the action that stamps `taskRun` as finished now.
    *
    * The update is restricted to the row identified by the run's primary key
    * `(name, runId)`; an unfiltered update would overwrite every row in the table.
    *
    * @param taskRun the run to mark as finished
    * @return an update action yielding the number of affected rows
    */
  def markFinished(taskRun: TaskRun) =
    tasks
      .filter(row => row.name === taskRun.name && row.runId === taskRun.runId)
      .update(taskRun.copy(finishedAt = Some(Instant.now())))

  /**
    * Starts a new run of `taskName` if `shouldRun` approves, given the latest recorded run.
    *
    * Steps: look up the run with the highest `run_id` for the task, ask `shouldRun`, and if it
    * agrees insert a new row with the next run id (1 when there is no previous run). When the
    * insert succeeds, `doAction` is invoked and the run is marked finished afterwards, whether
    * the action succeeded, returned a failed future, or threw. The value produced by
    * `doAction` is discarded.
    *
    * @param shouldRun policy applied to the latest run (None when the task never ran)
    * @param taskName  name of the task
    * @param worker    identifier of the worker recorded on the new run
    * @param doAction  the work to perform once the run has been recorded
    * @return a future of [[SkipRun]] when the policy declined, [[InsertFailed]] when the new
    *         row could not be inserted (presumably because another worker claimed the same
    *         run id first and the primary key rejected it -- verify for your database), or
    *         [[StartRun]] once the action has completed and the run was marked finished
    */
  def startTaskIf[A](shouldRun: ShouldRun, taskName: TaskName, worker: String)(
      doAction: (StartRun) => Future[A])(implicit ec: ExecutionContext) = {
    val latestRunId = tasks.filter(_.name === taskName).map(_.runId).max
    val latestRun = tasks
      .filter(_.name === taskName)
      .filter(_.runId === latestRunId)
      .result
      .headOption

    val insertTaskRun = latestRun.map {
      case last if !shouldRun(last) =>
        SkipRun(last)
      case None =>
        StartRun(None, TaskRun(taskName, 1, worker))
      case Some(last) =>
        StartRun(Some(last), TaskRun(taskName, last.runId + 1, worker))
    }.flatMap {
      case r @ StartRun(_, thisRun) =>
        // Capture an insert failure as a value instead of failing the whole action.
        (tasks += thisRun).asTry.map {
          case Failure(ex) => InsertFailed(ex)
          case _ => r
        }
      case r => DBIO.successful(r)
    }

    database.run(insertTaskRun).flatMap {
      case sr @ StartRun(_, thisRun) =>
        for {
          // Invoke doAction inside a Future callback so that an exception thrown synchronously
          // by the action becomes a failed future; then convert any failure into a value so
          // we still mark the task as finished.
          _ <- Future.successful(()).flatMap(_ => doAction(sr)).map(Success.apply).recover {
            case ex => Failure(ex)
          }
          _ <- database.run(markFinished(thisRun))
        } yield sr
      case other => Future.successful(other)
    }
  }

  /**
    * Helper to see whether an [[Instant]] was more than [[Duration]] ago,
    * i.e. "was 2017-01-01 13:13 more than 5 hours ago?"
    *
    * @return true when `instant + duration` lies strictly before the current time
    */
  private def wasMoreThanDurationAgo(duration: Duration,
                                     instant: Instant): Boolean =
    instant.plus(duration).isBefore(Instant.now())
}
import java.time.Duration
import slick.driver.H2Driver
import slick.driver.H2Driver.api._
import scala.concurrent.Await
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
/**
  * Created by chad on 1/5/17.
  *
  * Small demonstration of [[TaskQueue]] against an in-memory H2 database: creates the
  * schema, tries to start the "print-hello" task as "worker-1" using the
  * `HasNotStartedInLast` policy with a 10 minute window, and prints the outcome.
  */
object ExampleUsage extends App {
  val db = Database.forURL("jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1")
  val queue = new TaskQueue[H2Driver]("tasks", H2Driver, db)
  val createDb = db.run(queue.tasks.schema.create)

  // Brings the policy case classes (e.g. HasNotStartedInLast) into scope.
  import queue._

  val done =
    createDb.flatMap { _ =>
      queue.startTaskIf(HasNotStartedInLast(Duration.ofMinutes(10)), "print-hello", "worker-1") { start =>
        println(s"Last print was at $start.")
        println("HELLO")
        // The action throws instead of returning a Future -- apparently to exercise the
        // failure path of startTaskIf.
        throw new Exception("SDF")
      }
    }

  // Block only at the very edge of the program, and always release the database's
  // resources, even when the awaited future failed or timed out.
  try println(Await.result(done, 10.seconds))
  finally db.close
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment