Skip to content

Instantly share code, notes, and snippets.

@alexandrnikitin
Created May 1, 2017 18:46
Show Gist options
  • Save alexandrnikitin/7a5107b1ddf3d0c6a6eacb5954d5db66 to your computer and use it in GitHub Desktop.
Save alexandrnikitin/7a5107b1ddf3d0c6a6eacb5954d5db66 to your computer and use it in GitHub Desktop.
Aho-Corasick with some perf improvements
using System;
using System.Collections.Generic;
namespace Adform.AdServing.AhoCorasickTree.Sandbox.V7g
{
public class AhoCorasickTree
{
/// <summary>
/// Root node of the keyword trie; every search starts here.
/// </summary>
internal AhoCorasickTreeNode Root { get; set; }

/// <summary>
/// Builds the automaton: inserts every keyword into the trie and then wires
/// the failure links.
/// </summary>
/// <param name="keywords">
/// Patterns to search for. May be null, in which case the automaton is empty
/// and never matches.
/// </param>
public AhoCorasickTree(IEnumerable<string> keywords)
{
    Root = new AhoCorasickTreeNode();
    if (keywords != null)
    {
        foreach (var p in keywords)
        {
            AddPatternToTree(p);
        }
    }

    // Wire the failure links even when there are no keywords. This leaves
    // Root.Failure pointing at Root instead of null, so search code that
    // follows Root.Failure does not dereference null on an empty automaton.
    SetFailureNodes();
}
/// <summary>
/// Returns true as soon as any keyword occurs anywhere in <paramref name="text"/>.
/// </summary>
/// <param name="text">Text to scan. Must not be null.</param>
/// <returns>True if at least one keyword is a substring of the text.</returns>
public bool Contains(string text)
{
    var currentNode = Root;
    for (var i = 0; i < text.Length; i++)
    {
        var c = text[i];
        while (true)
        {
            var node = currentNode.GetTransition(c);
            if (node != null)
            {
                // IsWord is set on a node when it ends a keyword (and is
                // propagated along failure links during construction).
                if (node.IsWord)
                {
                    return true;
                }
                currentNode = node;
                break;
            }

            // No transition even from the root: this character cannot extend
            // or start any keyword, so move on to the next character.
            if (currentNode == Root)
            {
                break;
            }

            // Fall back and retry the SAME character, including at the root.
            // Breaking as soon as the fallback reaches the root would drop the
            // character and miss matches such as keyword "ab" in text "aab".
            currentNode = currentNode.Failure;
        }
    }
    return false;
}
/// <summary>
/// Runs the result-based search in "only starts" mode: the scan gives up and
/// returns false at the first character for which no transition can be found.
/// </summary>
/// <param name="text">Text to scan. Must not be null.</param>
/// <returns>True if a node carrying results is reached before the scan gives up.</returns>
/// <remarks>
/// NOTE(review): failure links are still followed before giving up, so a
/// reported match does not necessarily begin at index 0 - confirm this is the
/// intended meaning of "start".
/// </remarks>
public bool ContainsThatStart(string text)
{
    return Contains(text, onlyStarts: true);
}
/// <summary>
/// Scans <paramref name="text"/> and reports whether a node holding at least
/// one result is reached.
/// </summary>
/// <param name="text">Text to scan. Must not be null.</param>
/// <param name="onlyStarts">
/// When true, the scan returns false at the first character that has no
/// transition (after following failure links down to the root).
/// </param>
/// <returns>True if a node with a non-empty result list is reached.</returns>
private bool Contains(string text, bool onlyStarts)
{
    var current = Root;
    foreach (var symbol in text)
    {
        // Follow failure links until the symbol can be consumed or the root
        // itself has no transition for it.
        AhoCorasickTreeNode next;
        while ((next = current.GetTransition(symbol)) == null && current != Root)
        {
            current = current.Failure;
        }

        if (next == null)
        {
            if (onlyStarts)
            {
                return false;
            }
        }
        else
        {
            current = next;
        }

        if (current.Results.Count != 0)
        {
            return true;
        }
    }
    return false;
}
/// <summary>
/// Lazily enumerates the results stored on every node visited while scanning
/// <paramref name="text"/>, in scan order.
/// </summary>
/// <param name="text">Text to scan. Must not be null.</param>
/// <returns>
/// The results of the node reached after each character. A keyword is
/// yielded once per position at which its node is reached.
/// </returns>
public IEnumerable<string> FindAll(string text)
{
    var current = Root;
    for (var i = 0; i < text.Length; i++)
    {
        // The helper may move 'current' along failure links even when it
        // finds no transition, hence the ref argument.
        var next = GetTransition(text[i], ref current);
        if (next != null)
        {
            current = next;
        }

        foreach (var match in current.Results)
        {
            yield return match;
        }
    }
}
/// <summary>
/// Finds the node reached by consuming <paramref name="c"/>, following
/// failure links from <paramref name="pointer"/> until a transition exists or
/// the root has been tried.
/// </summary>
/// <param name="c">Character to consume.</param>
/// <param name="pointer">
/// Current node; updated in place to the node at which the lookup stopped.
/// </param>
/// <returns>The target node, or null if not even the root has a transition for the character.</returns>
private AhoCorasickTreeNode GetTransition(char c, ref AhoCorasickTreeNode pointer)
{
    for (;;)
    {
        var next = pointer.GetTransition(c);
        if (next != null || pointer == Root)
        {
            return next;
        }
        pointer = pointer.Failure;
    }
}
/// <summary>
/// Assigns failure links to every node reachable from the root.
/// </summary>
private void SetFailureNodes()
{
    // First-level nodes fail to the root; deeper levels are resolved breadth-first.
    var secondLevel = FailToRootNode();
    FailUsingBFS(secondLevel);

    // Must stay last: while the BFS runs, Root.Failure is still null and that
    // null is what terminates the failure-chain walk in FailUsingBFS.
    Root.Failure = Root;
}
/// <summary>
/// Inserts <paramref name="pattern"/> into the trie, creating missing nodes
/// along the way, and marks the final node as the end of a keyword.
/// </summary>
/// <param name="pattern">Keyword to insert. Must not be null.</param>
private void AddPatternToTree(string pattern)
{
    var current = Root;
    for (var i = 0; i < pattern.Length; i++)
    {
        var symbol = pattern[i];
        var child = current.GetTransition(symbol);
        if (child == null)
        {
            child = current.AddTransition(symbol);
        }
        current = child;
    }

    current.IsWord = true;
    current.AddResult(pattern);
}
/// <summary>
/// Points the failure link of every direct child of the root at the root.
/// </summary>
/// <returns>All grandchildren of the root, i.e. the first level that still needs failure links.</returns>
private List<AhoCorasickTreeNode> FailToRootNode()
{
    var secondLevel = new List<AhoCorasickTreeNode>();
    var firstLevel = Root.Transitions;
    for (var i = 0; i < firstLevel.Length; i++)
    {
        var child = firstLevel[i];
        child.Failure = Root;
        secondLevel.AddRange(child.Transitions);
    }
    return secondLevel;
}
/// <summary>
/// Assigns failure links level by level, starting from <paramref name="nodes"/>,
/// and lets each node inherit the results and word flag of its failure target.
/// </summary>
/// <param name="nodes">The nodes of the first level to process.</param>
/// <remarks>
/// Relies on Root.Failure still being null while it runs: the walk up the
/// failure chain stops when it reaches null.
/// </remarks>
private void FailUsingBFS(List<AhoCorasickTreeNode> nodes)
{
    while (nodes.Count != 0)
    {
        var newNodes = new List<AhoCorasickTreeNode>();
        foreach (var node in nodes)
        {
            // Walk the parent's failure chain until some node can consume this
            // node's character, or the chain runs out (null).
            var failure = node.ParentFailure;
            var value = node.Value;
            while (failure != null && !failure.ContainsTransition(value))
            {
                failure = failure.Failure;
            }

            if (failure == null)
            {
                node.Failure = Root;
            }
            else
            {
                node.Failure = failure.GetTransition(value);
                node.AddResults(node.Failure.Results);

                // The word flag must come from the failure TARGET (node.Failure),
                // not from 'failure', which is the node one character before it.
                // Using 'failure' loses matches: with keywords {"abc", "b"} the
                // node "ab" fails to "b" but would inherit the root's flag.
                if (node.Failure.IsWord)
                {
                    node.IsWord = true;
                }
            }
            newNodes.AddRange(node.Transitions);
        }
        nodes = newNodes;
    }
}
/// <summary>
/// Serializes the automaton into the flat byte layout read by
/// <see cref="AhoCorasickTreeSlim"/>.
/// </summary>
/// <remarks>
/// Nodes are written breadth-first; nodes without a transition table are not
/// written at all. Each written node is:
/// one byte with the number of slots, two bytes with the offset of the failure
/// node, then one (key byte, two-byte target offset) pair per slot. Empty
/// slots are written as zeros. A transition into a node flagged IsWord is
/// written with target offset 0, which the slim reader treats as "match".
/// </remarks>
/// <returns>A slim automaton over the serialized bytes.</returns>
/// <exception cref="NotSupportedException">
/// A slot count or transition character does not fit in one byte, or a node
/// offset does not fit in two bytes. Such values used to be truncated
/// silently, producing an automaton that matched the wrong text.
/// </exception>
public AhoCorasickTreeSlim BuildSlim()
{
    SetOffsets();
    var data = new List<byte>();
    var queue = new Queue<AhoCorasickTreeNode>();
    queue.Enqueue(Root);
    while (queue.Count > 0)
    {
        var currentNode = queue.Dequeue();
        var entries = currentNode._entries;
        if (entries.Length == 0) continue;

        data.Add(ToSlimByte(entries.Length, "transition table size"));
        data.AddRange(BitConverter.GetBytes(ToSlimOffset(currentNode.Failure.Offset)));
        foreach (var entry in entries)
        {
            if (entry.Key != 0)
            {
                // Children are enqueued even when they are word nodes so that
                // the layout stays in step with the offsets from SetOffsets.
                queue.Enqueue(entry.Value);
                data.Add(ToSlimByte(entry.Key, "transition character"));
                data.AddRange(BitConverter.GetBytes(
                    entry.Value.IsWord ? (char)0 : ToSlimOffset(entry.Value.Offset)));
            }
            else
            {
                data.AddRange(new byte[SizeOfKey]);
                data.AddRange(new byte[SizeOfNode]);
            }
        }
    }
    return new AhoCorasickTreeSlim(data.ToArray());
}

// Narrows a value to the single byte used for slot counts and keys, refusing to truncate.
private static byte ToSlimByte(int value, string what)
{
    if (value < 0 || value > byte.MaxValue)
    {
        throw new NotSupportedException(
            string.Format("Cannot build the slim automaton: {0} {1} does not fit in one byte.", what, value));
    }
    return (byte)value;
}

// Narrows a node offset to the two bytes used for links, refusing to truncate.
private static char ToSlimOffset(int offset)
{
    if (offset < 0 || offset > char.MaxValue)
    {
        throw new NotSupportedException(
            string.Format("Cannot build the slim automaton: node offset {0} does not fit in two bytes.", offset));
    }
    return (char)offset;
}
// Byte widths of the fields of a serialized node record; used by BuildSlim
// and calcOffset to lay out and address the flat byte array.

// Width of the slot-count prefix of a node record.
private const int SizeOfSize = sizeof(byte);
// Width of the failure-node offset stored after the slot count.
private const int SizeOfFailure = sizeof(char);
// Width of the key stored in each slot.
private const int SizeOfKey = sizeof(byte);
// Width of the target-node offset stored in each slot.
private const int SizeOfNode = sizeof(char);
/// <summary>
/// Walks the trie breadth-first and stores, on every node that has a
/// transition table, the byte offset computed by calcOffset from the number
/// of tables and slots visited before it.
/// </summary>
private void SetOffsets()
{
    var tablesSeen = 0;
    var slotsSeen = 0;
    var pending = new Queue<AhoCorasickTreeNode>();
    pending.Enqueue(Root);
    while (pending.Count > 0)
    {
        var node = pending.Dequeue();
        var slots = node._entries;
        if (slots.Length == 0)
        {
            // Nodes without a table keep whatever Offset they already had.
            continue;
        }

        node.Offset = calcOffset(tablesSeen, slotsSeen);
        tablesSeen++;

        // Every slot counts towards the offset, occupied or not.
        slotsSeen += slots.Length;
        for (var i = 0; i < slots.Length; i++)
        {
            if (slots[i].Key != 0)
            {
                pending.Enqueue(slots[i].Value);
            }
        }
    }
}
/// <summary>
/// Computes the number of bytes occupied by <paramref name="roots"/> node
/// headers plus <paramref name="childs"/> slots.
/// </summary>
/// <param name="roots">Number of node headers (slot count + failure offset).</param>
/// <param name="childs">Number of slots (key + target offset).</param>
/// <returns>The combined size in bytes.</returns>
private int calcOffset(int roots, int childs)
{
    const int headerSize = SizeOfSize + SizeOfFailure;
    const int slotSize = SizeOfKey + SizeOfNode;

    var headerBytes = roots * headerSize;
    var slotBytes = childs * slotSize;
    return headerBytes + slotBytes;
}
}
}
using System.Collections.Generic;
using System.Linq;
namespace Adform.AdServing.AhoCorasickTree.Sandbox.V7g
{
    /// <summary>
    /// A trie node whose child transitions live in a collision-free table:
    /// a character is stored at index <c>c &amp; (size - 1)</c>, and the table
    /// is doubled until no two keys share an index.
    /// </summary>
    /// <remarks>
    /// A key of 0 marks an empty slot, so a transition on the character
    /// '\0' can be stored but is never found again by <see cref="GetTransition"/>.
    /// </remarks>
    internal class AhoCorasickTreeNode
    {
        /// <summary>Character on the edge leading into this node.</summary>
        public char Value { get; private set; }

        /// <summary>Failure link; assigned by the tree after all keywords are inserted.</summary>
        public AhoCorasickTreeNode Failure { get; set; }

        /// <summary>Word flag, set by the owning tree.</summary>
        public bool IsWord;

        private readonly List<string> _results;
        private readonly AhoCorasickTreeNode _parent;

        /// <summary>Results attached to this node.</summary>
        public List<string> Results { get { return _results; } }

        /// <summary>Failure link of the parent, or null for a node without a parent.</summary>
        public AhoCorasickTreeNode ParentFailure { get { return _parent == null ? null : _parent.Failure; } }

        /// <summary>Child nodes of all occupied slots; builds a new array on every call.</summary>
        public AhoCorasickTreeNode[] Transitions { get { return _entries.Where(x => x.Key != 0).Select(x => x.Value).ToArray(); } }

        /// <summary>Offset slot for the owning tree's serializer.</summary>
        public int Offset { get; set; }

        internal Entry[] _entries;
        private int _size;

        /// <summary>Creates a node without a parent.</summary>
        public AhoCorasickTreeNode() : this(null, ' ')
        {
        }

        private AhoCorasickTreeNode(AhoCorasickTreeNode parent, char value)
        {
            Value = value;
            _parent = parent;
            _results = new List<string>();
            _entries = new Entry[0];
        }

        /// <summary>Adds <paramref name="result"/> unless it is already present.</summary>
        public void AddResult(string result)
        {
            if (!_results.Contains(result))
            {
                _results.Add(result);
            }
        }

        /// <summary>Adds every element of <paramref name="results"/> that is not already present.</summary>
        public void AddResults(IEnumerable<string> results)
        {
            foreach (var result in results)
            {
                AddResult(result);
            }
        }

        /// <summary>
        /// Creates a child for <paramref name="c"/> and stores it, growing the
        /// table until the character's slot is free. An existing transition on
        /// the same character is replaced.
        /// </summary>
        /// <returns>The newly created child node.</returns>
        public AhoCorasickTreeNode AddTransition(char c)
        {
            var node = new AhoCorasickTreeNode(this, c);
            if (_size == 0) Resize();
            while (true)
            {
                var ind = c & (_size - 1);
                if (_entries[ind].Key != 0 && _entries[ind].Key != c)
                {
                    // Slot taken by another character: grow and try again.
                    Resize();
                    continue;
                }
                _entries[ind].Key = c;
                _entries[ind].Value = node;
                return node;
            }
        }

        /// <summary>Returns the child reached on <paramref name="c"/>, or null if there is none.</summary>
        public AhoCorasickTreeNode GetTransition(char c)
        {
            if (_size == 0) return null;
            var ind = c & (_size - 1);
            var keyThere = _entries[ind].Key;
            if (keyThere != 0 && (keyThere == c))
            {
                return _entries[ind].Value;
            }
            return null;
        }

        /// <summary>Returns true if there is a child for <paramref name="c"/>.</summary>
        public bool ContainsTransition(char c)
        {
            return GetTransition(c) != null;
        }

        // Doubles the table (or creates a one-slot table the first time).
        private void Resize()
        {
            var newSize = _entries.Length * 2;
            if (newSize == 0) newSize = 1;
            Resize(newSize);
        }

        // Re-inserts all occupied slots into a table of newSize slots, doubling
        // again whenever two existing keys would collide.
        private void Resize(int newSize)
        {
            var newEntries = new Entry[newSize];
            for (var i = 0; i < _entries.Length; i++)
            {
                var key = _entries[i].Key;
                if (key == 0)
                {
                    // Empty slot. Re-inserting it would map to index 0 and be
                    // mistaken for a collision with any real key living there,
                    // forcing the table to keep doubling for no reason.
                    continue;
                }

                var ind = key & (newSize - 1);
                if (newEntries[ind].Key != 0 && newEntries[ind].Key != key)
                {
                    Resize(newSize * 2);
                    return;
                }
                newEntries[ind].Key = key;
                newEntries[ind].Value = _entries[i].Value;
            }
            _entries = newEntries;
            _size = newSize;
        }
    }

    /// <summary>One slot of a node's transition table; Key == 0 means the slot is empty.</summary>
    internal struct Entry
    {
        public char Key;
        public AhoCorasickTreeNode Value;
    }
}
using System.Runtime.CompilerServices;
namespace Adform.AdServing.AhoCorasickTree.Sandbox.V7g
{
/// <summary>
/// Aho-Corasick matcher over a flat byte array.
/// </summary>
/// <remarks>
/// Layout of one node record, as read by this class: one byte holding the
/// number of slots, a two-byte offset of the failure node, then per slot a
/// one-byte key followed by a two-byte offset of the target node. Offsets are
/// relative to the start of the array; the record at offset 0 is the root.
/// A slot whose target offset is 0 signals a match. A key of 0 is treated as
/// an empty slot.
/// </remarks>
public class AhoCorasickTreeSlim
{
    private readonly byte[] _data;

    /// <summary>Wraps serialized automaton data; the array is used as-is, not copied.</summary>
    public AhoCorasickTreeSlim(byte[] data)
    {
        _data = data;
    }

    /// <summary>The serialized automaton passed to the constructor.</summary>
    public byte[] Data => _data;

    private const int SizeOfSize = sizeof(byte);
    private const int SizeOfFailure = sizeof(char);
    private const int SizeOfKey = sizeof(byte);
    private const int SizeOfNode = sizeof(char);

    // Key byte of slot 'ind' of the node record at currentNodePtr.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static unsafe byte GetKey(byte* currentNodePtr, int ind)
    {
        return *(byte*)(currentNodePtr + SizeOfSize + SizeOfFailure + ind * (SizeOfKey + SizeOfNode));
    }

    // Address of the node that slot 'ind' points to (b is the start of the data).
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private unsafe byte* GetNext(byte* b, byte* currentNodePtr, int ind)
    {
        return b + *(char*)(currentNodePtr + (SizeOfSize + SizeOfFailure + ind * (SizeOfKey + SizeOfNode) + SizeOfKey));
    }

    // Address of the failure node of the record at currentNodePtr.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private unsafe byte* GetFailure(byte* b, byte* currentNodePtr)
    {
        return b + *(char*)(currentNodePtr + SizeOfSize);
    }

    /// <summary>
    /// Returns true as soon as the automaton reports a match while scanning
    /// <paramref name="text"/>.
    /// </summary>
    /// <param name="text">Text to scan. Must not be null.</param>
    /// <returns>True on the first match; false if the text is exhausted or the automaton is empty.</returns>
    public unsafe bool Contains(string text)
    {
        var length = text.Length;

        // 'fixed' yields a null pointer for a null or zero-length array, and
        // the loop below would dereference it. An empty automaton matches nothing.
        if (_data == null || _data.Length == 0)
        {
            return false;
        }

        fixed (byte* b = _data)
        fixed (char* p = text)
        {
            var currentNodePtr = b;
            var cptr = p;
            var end = p + length;
            while (cptr < end)
            {
                var c = *cptr;
                cptr++;
                while (true)
                {
                    var size = *currentNodePtr;
                    var ind = c & (size - 1);
                    var key = GetKey(currentNodePtr, ind);

                    // key != 0: an empty slot must not match the character '\0';
                    // its zero target offset would be read as a match.
                    if (key != 0 && key == c)
                    {
                        currentNodePtr = GetNext(b, currentNodePtr, ind);
                        if (currentNodePtr == b) return true;
                        break;
                    }

                    // Even the root cannot consume this character: skip it.
                    if (currentNodePtr == b) break;

                    // Fall back and retry the SAME character, including at the
                    // root; dropping it there misses e.g. "ab" inside "aab".
                    currentNodePtr = GetFailure(b, currentNodePtr);
                }
            }
        }
        return false;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment