Skip to content

Instantly share code, notes, and snippets.

@rikace
Created July 5, 2015 21:21
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save rikace/28ed06547b043cd6e660 to your computer and use it in GitHub Desktop.
Save rikace/28ed06547b043cd6e660 to your computer and use it in GitHub Desktop.
Parallelizing async tasks with dependencies using a DAG
namespace Async
module DAG =
open System
open System.Collections.Generic
open System.Threading
open Microsoft.FSharp.Collections
/// Short name for the BCL mutable list. It is fully qualified because
/// Microsoft.FSharp.Collections is opened after System.Collections.Generic,
/// so a bare `List<'a>` would resolve to the F# immutable list instead.
type MList<'a> = System.Collections.Generic.List<'a>

/// Everything the scheduler tracks about a single registered task.
type TaskInfo =
    { /// Execution context captured at registration; when non-null the task
      /// body is run under a copy of it.
      Context : System.Threading.ExecutionContext
      /// Ids of the tasks that must complete before this task may start.
      Edges : int array
      /// Caller-chosen identifier of this task.
      Id : int
      /// The work to perform.
      Task : unit -> unit
      /// None until execution is prepared; then the number of dependencies
      /// that have not finished yet (the task becomes runnable at 0).
      NumRemainingEdges : int option
      /// Time the task body started (filled in on the completion notification).
      Start : DateTimeOffset option
      /// Time the task body finished (filled in on the completion notification).
      End : DateTimeOffset option }

/// Messages understood by the scheduler's agent.
and TaskMessage =
    /// Register a task under the given id.
    | AddTask of int * TaskInfo
    /// Run a task whose dependencies are all satisfied.
    | QueueTask of TaskInfo
    /// Validate the registered graph and start the tasks that have no dependencies.
    | ExecuteTasks
/// Runs a set of tasks that declare dependencies on one another, starting each
/// task on the thread pool once every task listed in its Edges has completed.
/// Register tasks with AddTask, then call ExecuteTasks. OnTaskCompleted is
/// triggered after each task body returns.
type ParallelTasksDAG() =
    let onTaskCompleted = new Event<TaskInfo>()

    // Guards the completion bookkeeping done in the QueueTask handler. That code
    // runs inside Async.Start on thread-pool threads, so tasks that finish at
    // the same time would otherwise race on the shared Dictionary instances.
    // For example, two parents could both read a child's NumRemainingEdges as
    // Some 2 and both write Some 1, so the child would never be queued.
    let completionSync = obj ()

    /// Raises InvalidOperationException listing every id that appears in some
    /// task's Edges but was never registered as a task.
    let verifyThatAllOperationsHaveBeenRegistered (tasks : Dictionary<int, TaskInfo>) =
        let tasksNotRegistered =
            tasks.Values
            |> Seq.collect (fun f -> f.Edges)
            |> set
            |> Seq.filter (tasks.ContainsKey >> not)
            |> Seq.toArray
        if tasksNotRegistered.Length > 0 then
            let edgesMissing = tasksNotRegistered |> Array.map string
            raise (InvalidOperationException(sprintf "Missing operation: %s" (String.Join(", ", edgesMissing))))

    /// Works on private copies of the dependency lists. It repeatedly removes
    /// every task with no outstanding dependencies until none are left.
    /// Returns Some ordering, in which each id appears after the ids it depends
    /// on, or None when some tasks can never become free (a cycle).
    /// Expects every edge id to be registered (checked beforehand).
    let verifyTopologicalSort (tasks : Dictionary<int, TaskInfo>) =
        // Build up the dependency graph in both directions.
        let tasksToFrom = new Dictionary<int, MList<int>>(tasks.Values.Count, HashIdentity.Structural)
        let tasksFromTo = new Dictionary<int, MList<int>>(tasks.Values.Count, HashIdentity.Structural)
        for op in tasks.Values do
            // op.Id depends on each of op.Edges.
            tasksToFrom.Add(op.Id, new MList<int>(op.Edges))
            // Each of op.Edges is relied on by op.Id.
            for deptId in op.Edges do
                let success, _ = tasksFromTo.TryGetValue(deptId)
                if not success then tasksFromTo.Add(deptId, new MList<int>())
                tasksFromTo.[deptId].Add(op.Id)
        let partialOrderingIds = new MList<int>(tasksToFrom.Count)
        let iterationIds = new MList<int>(tasksToFrom.Count)
        let rec buildOverallPartialOrderingIds () =
            match tasksToFrom.Count with
            | 0 -> Some(partialOrderingIds)
            | _ ->
                iterationIds.Clear()
                for item in tasksToFrom do
                    if item.Value.Count = 0 then
                        iterationIds.Add(item.Key)
                        let success, depIds = tasksFromTo.TryGetValue(item.Key)
                        if success then
                            // Remove all outbound edges. This mutates only the
                            // lists stored as values, not the dictionary being
                            // enumerated.
                            for depId in depIds do
                                tasksToFrom.[depId].Remove(item.Key) |> ignore
                // If nothing was freed in this pass, there is no valid ordering.
                if iterationIds.Count = 0 then None
                else
                    // Remove the freed items and append them to the overall ordering.
                    for id in iterationIds do
                        tasksToFrom.Remove(id) |> ignore
                    partialOrderingIds.AddRange(iterationIds)
                    buildOverallPartialOrderingIds ()
        buildOverallPartialOrderingIds ()

    /// Raises InvalidOperationException("Cycle detected") when the registered
    /// tasks cannot be ordered.
    let verifyThereAreNoCycles (operations : Dictionary<int, TaskInfo>) =
        if verifyTopologicalSort operations |> Option.isNone then
            raise (InvalidOperationException("Cycle detected"))

    /// Decrements a remaining-dependency count; None stays None.
    let nrd (remaining : int option) = Option.map (fun n -> n - 1) remaining

    /// For each id in `dep`, decrements that task's NumRemainingEdges in `ops`
    /// (writing the updated record back). Returns, prepended to `acc`, the
    /// tasks whose count reached 0. Mutates `ops`, so callers must hold
    /// completionSync.
    /// NOTE(review): `.Value` throws if the count is still None; this is only
    /// reached after ExecuteTasks has set it to Some.
    let rec getDependentOperation (dep : int list) (ops : Dictionary<int, TaskInfo>) acc =
        match dep with
        | [] -> acc
        | h :: t ->
            ops.[h] <- { ops.[h] with NumRemainingEdges = nrd ops.[h].NumRemainingEdges }
            match ops.[h].NumRemainingEdges.Value with
            | 0 -> getDependentOperation t ops (ops.[h] :: acc)
            | _ -> getDependentOperation t ops acc

    /// Runs the task body. When a context was captured, the body runs under a
    /// copy of that ExecutionContext.
    let runTask (op : TaskInfo) =
        match op.Context with
        | null -> op.Task()
        | ctx -> ExecutionContext.Run(ctx.CreateCopy(), (fun state -> (state :?> TaskInfo).Task()), op)

    /// The agent that owns registration and scheduling. Exceptions raised while
    /// processing a message (for example by the verification steps) leave the
    /// receive loop and are printed by the Error handler.
    let dagAgent =
        let inbox = new MailboxProcessor<TaskMessage>(fun inbox ->
            let rec loop (tasks : Dictionary<int, TaskInfo>) (edges : Dictionary<int, int list>) =
                async {
                    let! msg = inbox.Receive()
                    match msg with
                    | ExecuteTasks ->
                        verifyThatAllOperationsHaveBeenRegistered tasks
                        verifyThereAreNoCycles tasks
                        // dependenciesFromTo: task id -> ids of the tasks waiting on it.
                        let dependenciesFromTo = new Dictionary<int, int list>()
                        let operations' = new Dictionary<int, TaskInfo>()
                        for KeyValue(key, value) in tasks do
                            let operation' = { value with NumRemainingEdges = Some(value.Edges.Length) }
                            for dependencyId in operation'.Edges do
                                let exists, lstDependencies = dependenciesFromTo.TryGetValue(dependencyId)
                                if not exists then
                                    dependenciesFromTo.Add(dependencyId, [ operation'.Id ])
                                else
                                    dependenciesFromTo.[dependencyId] <- operation'.Id :: lstDependencies
                            operations'.Add(key, operation')
                        // Start every task that has no dependencies.
                        operations'
                        |> Seq.filter (fun kv ->
                            match kv.Value.NumRemainingEdges with
                            | Some 0 -> true
                            | _ -> false)
                        |> Seq.iter (fun kv -> inbox.Post(QueueTask(kv.Value)))
                        return! loop operations' dependenciesFromTo
                    | QueueTask(op) ->
                        let work =
                            async {
                                // Time and run the task's delegate.
                                // NOTE(review): an exception from the task body is not
                                // caught here, so its dependents are never queued.
                                let start' = DateTimeOffset.Now
                                runTask op
                                let end' = DateTimeOffset.Now
                                onTaskCompleted.Trigger { op with Start = Some start'; End = Some end' }
                                // Serialized: several completions may run at once
                                // (see completionSync).
                                let readyOperations =
                                    lock completionSync (fun () ->
                                        match edges.TryGetValue(op.Id) with
                                        | true, ((_ :: _) as dependents) ->
                                            let ready = getDependentOperation dependents tasks []
                                            edges.Remove(op.Id) |> ignore
                                            ready
                                        | _ -> [])
                                // Queue the tasks that this completion made runnable.
                                readyOperations |> List.iter (fun nestedOp -> inbox.Post(QueueTask(nestedOp)))
                            }
                        Async.Start work
                        return! loop tasks edges
                    | AddTask(id, op) ->
                        tasks.Add(id, op)
                        return! loop tasks edges
                }
            loop (new Dictionary<int, TaskInfo>(HashIdentity.Structural)) (new Dictionary<int, int list>(HashIdentity.Structural)))
        inbox.Error |> Observable.add (fun ex -> printfn "Error : %s" ex.Message)
        inbox.Start()
        inbox

    /// Triggered with a copy of the task record (Start/End filled in) after each
    /// task body returns.
    [<CLIEventAttribute>]
    member this.OnTaskCompleted = onTaskCompleted.Publish

    /// Asks the agent to validate the registered graph and start running it.
    /// This returns immediately; validation errors are reported through the
    /// agent's Error handler, not thrown to the caller.
    member this.ExecuteTasks() = dagAgent.Post ExecuteTasks

    /// Registers `task` under `id`. `edges` lists the ids of the tasks that must
    /// complete before this one; it is a ParamArray, so the ids may be passed as
    /// trailing arguments. The current ExecutionContext is captured for the run.
    member this.AddTask(id, task, [<ParamArrayAttribute>] edges : int array) =
        let data =
            { Context = ExecutionContext.Capture()
              Edges = edges
              Id = id
              Task = task
              NumRemainingEdges = None
              Start = None
              End = None }
        dagAgent.Post(AddTask(id, data))
// Demo actions. Each prints which action is running; actions 1, 2 and 4 then
// block their thread for 2, 3 and 5 seconds respectively to simulate work.
let acc1 () =
    printfn "action 1"
    System.Threading.Thread.Sleep(2000)

let acc2 () =
    printfn "action 2"
    System.Threading.Thread.Sleep(3000)

let acc3 () =
    printfn "action 3"

let acc4 () =
    printfn "action 4"
    System.Threading.Thread.Sleep(5000)

let acc5 () =
    printfn "action 5"

let acc6 () =
    printfn "action 6"

let acc7 () =
    printfn "action 7"

let acc8 () =
    printfn "action 8"

let acc9 () =
    printfn "action 9"

let acc10 () =
    printfn "action 10"

let acc11 () =
    printfn "action 11"

let acc12 () =
    printfn "action 12"
// Demo: register eight tasks and run them. The third element of each tuple
// lists the ids the task depends on (e.g. task 1 waits for tasks 4 and 5).
let dagAsync = Async.DAG.ParallelTasksDAG()

dagAsync.OnTaskCompleted
|> Observable.add (fun completed ->
    System.Console.ForegroundColor <- ConsoleColor.Magenta
    printfn "Completed %d" completed.Id)

[ (1, acc1, [| 4; 5 |])
  (2, acc2, [| 5 |])
  (3, acc3, [| 6; 5 |])
  (4, acc4, [| 6 |])
  (5, acc5, [| 7; 8 |])
  (6, acc6, [| 7 |])
  (7, acc7, [||])
  (8, acc8, [||]) ]
|> List.iter (fun (taskId, action, dependsOn) -> dagAsync.AddTask(taskId, action, dependsOn))

dagAsync.ExecuteTasks()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment