Skip to content

Instantly share code, notes, and snippets.

@pmunin
Last active May 15, 2017 00:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pmunin/e45692df0b0deee81008aa360e9c612c to your computer and use it in GitHub Desktop.
Save pmunin/e45692df0b0deee81008aa360e9c612c to your computer and use it in GitHub Desktop.
GraphAggregateUtils - Aggregates node values of directed acyclic graphs without double-counting shared child nodes
using DictionaryUtils;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using TestGraphUtils;
using Xunit;
namespace GraphAggregateUtils
{
public class GraphAggregateTests
{
    /// <summary>
    /// Sum of the costs of every node in <see cref="GenerateTestGraph"/> (0+1+2+3+4+5+6),
    /// i.e. each node counted exactly once even though node 3 is reachable through both 1 and 2.
    /// </summary>
    private const decimal ExpectedTotal = 21m;

    /// <summary>Reads the "Cost" entry of the node's data, falling back to 0 via GetOrAdd.</summary>
    decimal Cost(TestGraphNode node)
    {
        return node.Data.GetOrAdd("Cost", _ => 0m);
    }

    /// <summary>Sets the "Cost" entry of the node's data and returns the node.</summary>
    TestGraphNode Cost(TestGraphNode node, decimal value)
    {
        node.Data["Cost"] = value;
        return node;
    }

    /// <summary>
    /// Builds the test DAG (node cost == node name, node 0 has no cost set):
    /// <code>
    /// 0 -> 1 -> 3 -> 5
    ///           3 -> 6
    /// 0 -> 2 -> 3   (same instance as under 1)
    ///      2 -> 4
    /// </code>
    /// </summary>
    TestGraphNode GenerateTestGraph()
    {
        var v0 = new TestGraphNode() { Name = "0" };
        // Captured inside node 1's Link callback so node 2 can link to the very same instance.
        TestGraphNode v3 = null;
        v0.Link(v1 =>
        {
            v1.Name = "1";
            Cost(v1, 1);
            v1.Link(_v3 =>
            {
                v3 = _v3;
                v3.Name = "3";
                Cost(v3, 3);
                v3.Link(v5 =>
                {
                    v5.Name = "5";
                    Cost(v5, 5);
                })
                .Link(v6 =>
                {
                    v6.Name = "6";
                    Cost(v6, 6);
                });
            });
        })
        .Link(v2 =>
        {
            v2.Name = "2";
            Cost(v2, 2);
            v2
                .Link(v3)
                .Link(v4 =>
                {
                    v4.Name = "4";
                    Cost(v4, 4);
                });
        });
        return v0;
    }

    [Fact]
    public void Test1()
    {
        var v0 = GenerateTestGraph();

        var aggregate = GraphAggregateUtils.Aggregate(
            v0,
            n => Cost(n),
            (v1, v2) => v1 + v2,
            n => n.Links
            );

        // Shared node 3 (and its subtree 5, 6) must be counted only once.
        Assert.Equal(ExpectedTotal, aggregate.AggregatedValue);
    }

    [Fact]
    public void TestLazy1()
    {
        var v0 = GenerateTestGraph();

        var aggregate = TestLazyDFS(v0);
        var res = aggregate.GetTotalValue();

        Assert.Equal(ExpectedTotal, res);
    }

    /// <summary>
    /// Builds (once per node, memoized in the node's "aggregation" data entry) a lazy aggregation
    /// mirroring the node's links, so shared nodes map to a shared aggregation instance.
    /// </summary>
    private GraphNodeAggregationLazy<decimal> TestLazyDFS(TestGraphNode node)
    {
        return node.Data.GetOrAdd("aggregation", _ =>
        {
            var agg = new GraphNodeAggregationLazy<decimal>((v1, v2) => v1 + v2);
            agg.LocalValue = Cost(node);
            foreach (var childNode in node.Links)
                agg.AddChild(TestLazyDFS(childNode));
            return agg;
        });
    }
}
}
//Latest version here: https://gist.github.com/e45692df0b0deee81008aa360e9c612c.git
using DictionaryUtils;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace GraphAggregateUtils
{
/// <summary>
/// Allows to aggregate node values of acyclic directed graphs, without double counting shared children nodes
/// </summary>
public static class GraphAggregateUtils
{
    /// <summary>
    /// Aggregates the value of <paramref name="rootNode"/> together with the values of all nodes
    /// reachable from it, without double counting nodes shared by several parents.
    /// </summary>
    /// <param name="rootNode">Node to start the aggregation from.</param>
    /// <param name="getNodeValue">Returns the node's own (local) value.</param>
    /// <param name="aggregateValue">Combines two values into one.</param>
    /// <param name="getNodeLinks">Returns the direct children of a node; a null result is treated as "no children".</param>
    /// <param name="preAggregatedNodes">
    /// Optional cache of aggregations keyed by node; a new dictionary is created when null.
    /// </param>
    /// <returns>The aggregation of <paramref name="rootNode"/>.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="rootNode"/> or one of the delegates is null.
    /// </exception>
    public static GraphNodeAggregation<TNode, TValue> Aggregate<TNode, TValue>(
        TNode rootNode
        , Func<TNode, TValue> getNodeValue
        , Func<TValue, TValue, TValue> aggregateValue
        , Func<TNode, IEnumerable<TNode>> getNodeLinks
        , IDictionary<TNode, GraphNodeAggregation<TNode, TValue>> preAggregatedNodes = null
        )
    {
        // Fail fast with a meaningful exception instead of a NullReferenceException deep in the recursion.
        if (rootNode == null) throw new ArgumentNullException(nameof(rootNode));
        if (getNodeValue == null) throw new ArgumentNullException(nameof(getNodeValue));
        if (aggregateValue == null) throw new ArgumentNullException(nameof(aggregateValue));
        if (getNodeLinks == null) throw new ArgumentNullException(nameof(getNodeLinks));

        var args = new GraphNodeAggregation<TNode, TValue>.Args()
        {
            Node = rootNode,
            AggregateValue = aggregateValue,
            GetLinkedNodes = getNodeLinks,
            GetNodeValue = getNodeValue,
            AggregationByNode = preAggregatedNodes
        };
        return GetOrAddAggregation(args);
    }

    /// <summary>
    /// Returns the aggregation cached for <c>args.Node</c> in <c>args.AggregationByNode</c>,
    /// creating (and aggregating) it when absent. Missing <c>AggregationByNode</c> /
    /// <c>VisitedNodes</c> collections are created on <paramref name="args"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
    public static GraphNodeAggregation<TNode, TValue> GetOrAddAggregation<TNode, TValue>(GraphNodeAggregation<TNode, TValue>.Args args)
    {
        if (args == null) throw new ArgumentNullException(nameof(args));
        if (args.AggregationByNode == null)
            args.AggregationByNode = new Dictionary<TNode, GraphNodeAggregation<TNode, TValue>>();
        if (args.VisitedNodes == null)
            args.VisitedNodes = new HashSet<TNode>();
        return args.AggregationByNode.GetOrAdd(args.Node, node => new GraphNodeAggregation<TNode, TValue>(args));
    }

    /// <summary>Adds every item of <paramref name="itemsToAdd"/> to <paramref name="hashSet"/>.</summary>
    internal static void AddRange<T>(HashSet<T> hashSet, IEnumerable<T> itemsToAdd)
    {
        hashSet.UnionWith(itemsToAdd);
    }
}
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace GraphAggregateUtils
{
/// <summary>
/// Eagerly computed aggregation of one graph node: its own value combined with the values of all
/// nodes reachable from it, where a node reachable through several paths is counted only once.
/// </summary>
public class GraphNodeAggregation<TNode, TValue>
{
    /// <summary>
    /// Aggregates <c>args.Node</c> and, recursively, every node reachable from it.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
    /// <exception cref="InvalidOperationException">
    /// The node already has an entry in <c>args.AggregationByNode</c>, or the graph contains a cycle.
    /// </exception>
    public GraphNodeAggregation(Args args)
    {
        if (args == null) throw new ArgumentNullException(nameof(args));

        var node = args.Node;
        if (args.AggregationByNode != null && args.AggregationByNode.ContainsKey(node))
            throw new InvalidOperationException("This node is already aggregated");

        Node = node;
        LocalValue = args.GetNodeValue(node);
        AggregatedValue = LocalValue;

        // Nodes whose aggregation is being built further up the current recursion path.
        // The cache entry for a node is only stored after its constructor returns, so without
        // this check a cycle would recurse through GetOrAddAggregation until the stack overflows.
        var nodesInProgress = args.VisitedNodes ?? new HashSet<TNode>();
        if (!nodesInProgress.Add(node))
            throw new InvalidOperationException("The graph contains a cycle; only acyclic graphs can be aggregated");

        try
        {
            foreach (var childNode in args.GetLinkedNodes(node) ?? Enumerable.Empty<TNode>())
            {
                var childArgs = new Args()
                {
                    AggregateValue = args.AggregateValue,
                    AggregationByNode = args.AggregationByNode,
                    GetLinkedNodes = args.GetLinkedNodes,
                    GetNodeValue = args.GetNodeValue,
                    VisitedNodes = nodesInProgress,
                    Node = childNode
                };
                var childNodeAggregation = GraphAggregateUtils.GetOrAddAggregation(childArgs);
                AggregateChild(childNodeAggregation, args, true);
            }
        }
        finally
        {
            // Leaving this node's subtree: it is no longer on the current path, so reaching it
            // again from a sibling branch is a shared child, not a cycle.
            nodesInProgress.Remove(node);
        }
    }

    /// <summary>
    /// Folds <paramref name="childAggregation"/> into <see cref="AggregatedValue"/> without
    /// counting any node that is already in <see cref="AllChildrenAggregates"/>.
    /// </summary>
    /// <param name="childAggregation">Aggregation of a child or further descendant.</param>
    /// <param name="args">Supplies the <c>AggregateValue</c> combiner.</param>
    /// <param name="isDirectChild">Whether to record the child in <see cref="DirectChildrenAggregates"/>.</param>
    public void AggregateChild(GraphNodeAggregation<TNode, TValue> childAggregation, Args args, bool isDirectChild)
    {
        // Already counted through another path.
        if (AllChildrenAggregates.Contains(childAggregation))
            return;

        if (!AllChildrenAggregates.Overlaps(childAggregation.AllChildrenAggregates))
        {
            // Nothing of the child's subtree has been counted yet: take its whole total.
            AggregatedValue = args.AggregateValue(AggregatedValue, childAggregation.AggregatedValue);
        }
        else
        {
            // Part of the subtree is already counted: add only the child's own value,
            // then descend so each grandchild is checked individually.
            AggregatedValue = args.AggregateValue(AggregatedValue, childAggregation.LocalValue);
            foreach (var grandChildAgg in childAggregation.DirectChildrenAggregates)
                AggregateChild(grandChildAgg, args, false);
        }

        AllChildrenAggregates.Add(childAggregation);
        AllChildrenAggregates.UnionWith(childAggregation.AllChildrenAggregates);
        if (isDirectChild) DirectChildrenAggregates.Add(childAggregation);
    }

    /// <summary>The aggregated node.</summary>
    public TNode Node;
    /// <summary>The node's own value (from <c>GetNodeValue</c>).</summary>
    public TValue LocalValue;
    /// <summary>The node's value combined with every distinct descendant's value.</summary>
    public TValue AggregatedValue;
    /// <summary>Aggregations of direct children that were not already reachable through an earlier child.</summary>
    public HashSet<GraphNodeAggregation<TNode, TValue>> DirectChildrenAggregates = new HashSet<GraphNodeAggregation<TNode, TValue>>();
    /// <summary>Aggregations of all distinct descendants.</summary>
    public HashSet<GraphNodeAggregation<TNode, TValue>> AllChildrenAggregates = new HashSet<GraphNodeAggregation<TNode, TValue>>();

    /// <summary>Inputs of an aggregation run.</summary>
    public class Args
    {
        /// <summary>Cache of already built aggregations, keyed by node.</summary>
        public IDictionary<TNode, GraphNodeAggregation<TNode, TValue>> AggregationByNode;
        /// <summary>Node to aggregate.</summary>
        public TNode Node;
        /// <summary>Returns a node's direct children; null means none.</summary>
        public Func<TNode, IEnumerable<TNode>> GetLinkedNodes;
        /// <summary>Combines two values.</summary>
        public Func<TValue, TValue, TValue> AggregateValue;
        /// <summary>Returns a node's own value.</summary>
        public Func<TNode, TValue> GetNodeValue;
        /// <summary>
        /// Nodes on the current recursion path (aggregation in progress). Required for the acyclic
        /// check: meeting one of them again means the graph has a cycle.
        /// </summary>
        public HashSet<TNode> VisitedNodes;
    }
}
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace GraphAggregateUtils
{
/// <summary>
/// Node of a lazily aggregated graph. The total (own value plus every distinct descendant's value,
/// shared descendants counted once) is computed on first request and cached until this node or
/// any of its descendants changes.
/// </summary>
public abstract partial class GraphNodeAggregationLazyBase<TAggregatableValue>
{
    /// <summary>Combines two values.</summary>
    protected abstract TAggregatableValue AggregateValues(TAggregatableValue value1, TAggregatableValue value2);

    Dictionary<object, object> data = null;

    /// <summary>Free-form storage attached to this node; the dictionary is created on first access.</summary>
    public IDictionary<object, object> Data
    {
        get
        {
            return data ?? (data = new Dictionary<object, object>());
        }
    }

    /// <summary>Direct children of this node.</summary>
    public IEnumerable<GraphNodeAggregationLazyBase<TAggregatableValue>> Children => directChildren;

    TAggregatableValue localValue;

    /// <summary>This node's own value; setting it invalidates the cached totals of this node and its ancestors.</summary>
    public TAggregatableValue LocalValue { get { return localValue; } set { localValue = value; OnTotalChanged(); } }

    /// <summary>
    /// Drops the cached total of this node and of every ancestor. Ancestors must be included
    /// because their cached totals were computed from this node's value and children.
    /// </summary>
    protected void OnTotalChanged()
    {
        totalAggregatedLazy = null;
        if (parents.Count == 0)
            return;

        // Walk upwards iteratively; the "seen" set keeps shared ancestors in a DAG from being revisited.
        var seen = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> { this };
        var pending = new Stack<GraphNodeAggregationLazyBase<TAggregatableValue>>(parents);
        while (pending.Count > 0)
        {
            var ancestor = pending.Pop();
            if (!seen.Add(ancestor))
                continue;
            ancestor.totalAggregatedLazy = null;
            foreach (var ancestorParent in ancestor.parents)
                pending.Push(ancestorParent);
        }
    }

    HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> directChildren
        = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>>();

    // Reverse links, maintained by AddChild/RemoveChild, used only to invalidate ancestors' caches.
    readonly HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> parents
        = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>>();

    /// <summary>Adds a direct child and invalidates the cached totals of this node and its ancestors.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="child"/> is null.</exception>
    public void AddChild(GraphNodeAggregationLazyBase<TAggregatableValue> child)
    {
        if (child == null) throw new ArgumentNullException(nameof(child));
        if (directChildren.Add(child))
            child.parents.Add(this);
        OnTotalChanged();
    }

    /// <summary>Removes a direct child (if present) and invalidates the cached totals of this node and its ancestors.</summary>
    public void RemoveChild(GraphNodeAggregationLazyBase<TAggregatableValue> child)
    {
        if (child != null && directChildren.Remove(child))
            child.parents.Remove(this);
        OnTotalChanged();
    }

    Lazy<(TAggregatableValue totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> allChildren)> totalAggregatedLazy = null;

    /// <summary>Cached aggregation result, recreated after each invalidation.</summary>
    protected Lazy<(TAggregatableValue totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> allChildren)> AggregatedLazy
    {
        get
        {
            return totalAggregatedLazy ??
                (
                    totalAggregatedLazy
                        = new Lazy<(TAggregatableValue totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> allChildren)>(CalculateAggregated)
                );
        }
    }

    /// <summary>All distinct descendants (direct and indirect). This is the cached set itself — do not modify it.</summary>
    public IEnumerable<GraphNodeAggregationLazyBase<TAggregatableValue>> GetAllChildren()
    {
        return AggregatedLazy.Value.allChildren;
    }

    /// <summary>This node's value combined with the value of every distinct descendant.</summary>
    public TAggregatableValue GetTotalValue()
    {
        return AggregatedLazy.Value.totalValue;
    }

    /// <summary>Computes the total and the descendant set from scratch.</summary>
    protected (TAggregatableValue totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> allChildren) CalculateAggregated()
    {
        var res = (totalValue: this.LocalValue, allChildren: new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>>());
        foreach (var child in directChildren)
        {
            AggregateAppend(child, ref res);
        }
        return res;
    }

    /// <summary>
    /// Folds <paramref name="childToAggregate"/> into <paramref name="aggregateResult"/> without
    /// counting nodes already present in <c>aggregateResult.allChildren</c>.
    /// </summary>
    protected void AggregateAppend(GraphNodeAggregationLazyBase<TAggregatableValue> childToAggregate, ref (TAggregatableValue totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> allChildren) aggregateResult)
    {
        // Already counted through another path.
        if (aggregateResult.allChildren.Contains(childToAggregate))
            return;

        var hasOverlaps = aggregateResult.allChildren.Overlaps(childToAggregate.GetAllChildren());
        if (!hasOverlaps)
        {
            // Nothing of the child's subtree is counted yet: take its whole (cached) total.
            aggregateResult.totalValue = AggregateValues(aggregateResult.totalValue, childToAggregate.GetTotalValue());
        }
        else
        {
            // Part of the subtree is already counted: add only the child's own value,
            // then descend so each grandchild is checked individually.
            aggregateResult.totalValue = AggregateValues(aggregateResult.totalValue, childToAggregate.LocalValue);
            foreach (var grandChildAgg in childToAggregate.Children)
                AggregateAppend(grandChildAgg, ref aggregateResult);
        }

        aggregateResult.allChildren.Add(childToAggregate);
        aggregateResult.allChildren.UnionWith(childToAggregate.GetAllChildren());
    }
}

/// <summary>Lazy aggregation node whose values are combined by a caller-supplied delegate.</summary>
public partial class GraphNodeAggregationLazy<TAggregatableValue>
    : GraphNodeAggregationLazyBase<TAggregatableValue>
{
    /// <param name="aggregate">Combines two values.</param>
    /// <exception cref="ArgumentNullException"><paramref name="aggregate"/> is null.</exception>
    public GraphNodeAggregationLazy(Func<TAggregatableValue, TAggregatableValue, TAggregatableValue> aggregate)
    {
        this.AggregateDelegate = aggregate ?? throw new ArgumentNullException(nameof(aggregate));
    }

    /// <summary>The delegate used to combine values.</summary>
    public Func<TAggregatableValue, TAggregatableValue, TAggregatableValue> AggregateDelegate { get; private set; }

    protected override TAggregatableValue AggregateValues(TAggregatableValue value1, TAggregatableValue value2)
    {
        return AggregateDelegate(value1, value2);
    }
}
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace GraphAggregateUtils
{
/// <summary>
/// Node of a lazily aggregated graph whose per-node values of type <typeparamref name="TAggregatableValue"/>
/// are folded into an accumulator of type <typeparamref name="TAccumulate"/>. Shared descendants are
/// counted once; the result is cached until this node or any of its descendants changes.
/// </summary>
public abstract partial class GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>
{
    /// <summary>Folds a single node value into <paramref name="accumulatorMutable"/>.</summary>
    protected abstract void AggregateValues(ref TAccumulate accumulatorMutable, TAggregatableValue valueImmutable);

    /// <summary>Folds a child's complete accumulator (used when none of its subtree is counted yet) into <paramref name="accumulatorMutable"/>.</summary>
    protected abstract void AggregateAccumulators(ref TAccumulate accumulatorMutable, TAccumulate accumulatorImmutable);

    /// <summary>Initial accumulator value; <c>default(TAccumulate)</c> unless overridden.</summary>
    protected virtual TAccumulate AggregateSeed()
    {
        return default(TAccumulate);
    }

    /// <summary>Direct children of this node.</summary>
    public IEnumerable<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> Children => directChildren;

    TAggregatableValue localValue;

    /// <summary>This node's own value; setting it invalidates the cached totals of this node and its ancestors.</summary>
    public TAggregatableValue LocalValue { get { return localValue; } set { localValue = value; OnTotalChanged(); } }

    /// <summary>
    /// Drops the cached total of this node and of every ancestor. Ancestors must be included
    /// because their cached totals were computed from this node's value and children.
    /// </summary>
    protected void OnTotalChanged()
    {
        totalAggregatedLazy = null;
        if (parents.Count == 0)
            return;

        // Walk upwards iteratively; the "seen" set keeps shared ancestors in a DAG from being revisited.
        var seen = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> { this };
        var pending = new Stack<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>>(parents);
        while (pending.Count > 0)
        {
            var ancestor = pending.Pop();
            if (!seen.Add(ancestor))
                continue;
            ancestor.totalAggregatedLazy = null;
            foreach (var ancestorParent in ancestor.parents)
                pending.Push(ancestorParent);
        }
    }

    HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> directChildren
        = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>>();

    // Reverse links, maintained by AddChild/RemoveChild, used only to invalidate ancestors' caches.
    readonly HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> parents
        = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>>();

    /// <summary>Adds a direct child and invalidates the cached totals of this node and its ancestors.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="child"/> is null.</exception>
    public void AddChild(GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate> child)
    {
        if (child == null) throw new ArgumentNullException(nameof(child));
        if (directChildren.Add(child))
            child.parents.Add(this);
        OnTotalChanged();
    }

    /// <summary>Removes a direct child (if present) and invalidates the cached totals of this node and its ancestors.</summary>
    public void RemoveChild(GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate> child)
    {
        if (child != null && directChildren.Remove(child))
            child.parents.Remove(this);
        OnTotalChanged();
    }

    Lazy<(TAccumulate totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> allChildren)> totalAggregatedLazy = null;

    /// <summary>Cached aggregation result, recreated after each invalidation.</summary>
    protected Lazy<(TAccumulate totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> allChildren)> AggregatedLazy
    {
        get
        {
            return totalAggregatedLazy ??
                (
                    totalAggregatedLazy
                        = new Lazy<(TAccumulate totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> allChildren)>(CalculateAggregated)
                );
        }
    }

    /// <summary>All distinct descendants (direct and indirect). This is the cached set itself — do not modify it.</summary>
    public IEnumerable<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> GetAllChildren()
    {
        return AggregatedLazy.Value.allChildren;
    }

    /// <summary>Accumulator holding this node's value and the value of every distinct descendant.</summary>
    public TAccumulate GetTotalValue()
    {
        return AggregatedLazy.Value.totalValue;
    }

    /// <summary>Computes the accumulator and the descendant set from scratch, starting from <see cref="AggregateSeed"/>.</summary>
    protected (TAccumulate totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> allChildren) CalculateAggregated()
    {
        var res =
            (
                totalValue: AggregateSeed()
                , allChildren: new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>>()
            );
        AggregateValues(ref res.totalValue, this.LocalValue);
        foreach (var child in directChildren)
        {
            AggregateAppend(child, ref res);
        }
        return res;
    }

    /// <summary>
    /// Folds <paramref name="childToAggregate"/> into <paramref name="aggregateResult"/> without
    /// counting nodes already present in <c>aggregateResult.allChildren</c>.
    /// </summary>
    protected void AggregateAppend(GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate> childToAggregate
        , ref (TAccumulate totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> allChildren) aggregateResult
        )
    {
        // Already counted through another path.
        if (aggregateResult.allChildren.Contains(childToAggregate))
            return;

        var hasOverlaps = aggregateResult.allChildren.Overlaps(childToAggregate.GetAllChildren());
        if (!hasOverlaps)
        {
            // Nothing of the child's subtree is counted yet: merge its whole (cached) accumulator.
            AggregateAccumulators(ref aggregateResult.totalValue, childToAggregate.GetTotalValue());
        }
        else
        {
            // Part of the subtree is already counted: add only the child's own value,
            // then descend so each grandchild is checked individually.
            AggregateValues(ref aggregateResult.totalValue, childToAggregate.LocalValue);
            foreach (var grandChildAgg in childToAggregate.Children)
                AggregateAppend(grandChildAgg, ref aggregateResult);
        }

        aggregateResult.allChildren.Add(childToAggregate);
        aggregateResult.allChildren.UnionWith(childToAggregate.GetAllChildren());
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment