Created May 1, 2017 18:46
Save alexandrnikitin/7a5107b1ddf3d0c6a6eacb5954d5db66 to your computer and use it in GitHub Desktop.
Aho-Corasick with some perf improvements
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
namespace Adform.AdServing.AhoCorasickTree.Sandbox.V7g
{
    /// <summary>
    /// Aho-Corasick automaton built from a set of keywords. It answers "does the text contain any
    /// keyword" (<see cref="Contains(string)"/>), enumerates matches (<see cref="FindAll"/>) and can be
    /// flattened into a compact byte-array form with <see cref="BuildSlim"/>.
    /// </summary>
    public class AhoCorasickTree
    {
        // Serialized node layout (the same sizes are declared by AhoCorasickTreeSlim):
        //   node := size:byte failure:char slot[size]
        //   slot := key:byte next:char
        private const int SizeOfSize = sizeof(byte);
        private const int SizeOfFailure = sizeof(char);
        private const int SizeOfKey = sizeof(byte);
        private const int SizeOfNode = sizeof(char);

        internal AhoCorasickTreeNode Root { get; set; }

        /// <summary>
        /// Builds the trie from <paramref name="keywords"/> and computes the failure links.
        /// </summary>
        /// <param name="keywords">Keywords to search for; <c>null</c> yields a tree that matches nothing.</param>
        public AhoCorasickTree(IEnumerable<string> keywords)
        {
            Root = new AhoCorasickTreeNode();
            if (keywords != null)
            {
                foreach (var p in keywords)
                {
                    AddPatternToTree(p);
                }
            }

            // Always run, so Root.Failure is set (to Root) even when there are no keywords.
            SetFailureNodes();
        }

        /// <summary>
        /// Returns <c>true</c> as soon as any keyword is found in <paramref name="text"/>.
        /// </summary>
        /// <param name="text">Text to scan.</param>
        public bool Contains(string text)
        {
            var currentNode = Root;
            for (var i = 0; i < text.Length; i++)
            {
                var c = text[i];
                var node = currentNode.GetTransition(c);

                // Fall back along failure links until some state has a transition on c.
                // The root itself is tried as well before the character is dropped.
                while (node == null && currentNode != Root)
                {
                    currentNode = currentNode.Failure;
                    node = currentNode.GetTransition(c);
                }

                if (node == null)
                {
                    // Not even the root has a transition on c: stay at the root.
                    continue;
                }

                // IsWord is propagated along failure links (see FailUsingBFS), so it also
                // covers keywords that are suffixes of the path to this node.
                if (node.IsWord)
                {
                    return true;
                }

                currentNode = node;
            }

            return false;
        }

        /// <summary>
        /// Like <see cref="Contains(string)"/>, but returns <c>false</c> at the first character for which
        /// not even the root has a transition.
        /// </summary>
        /// <param name="text">Text to scan.</param>
        public bool ContainsThatStart(string text)
        {
            return Contains(text, true);
        }

        private bool Contains(string text, bool onlyStarts)
        {
            var pointer = Root;
            for (var i = 0; i < text.Length; i++)
            {
                var transition = GetTransition(text[i], ref pointer);
                if (transition != null)
                {
                    pointer = transition;
                }
                else if (onlyStarts)
                {
                    return false;
                }

                // Results include keywords inherited from failure targets.
                if (pointer.Results.Count > 0)
                {
                    return true;
                }
            }

            return false;
        }

        /// <summary>
        /// After each character of <paramref name="text"/>, yields every keyword in the current
        /// node's results (the same keyword may be yielded at several positions).
        /// </summary>
        /// <param name="text">Text to scan.</param>
        public IEnumerable<string> FindAll(string text)
        {
            var pointer = Root;
            foreach (var c in text)
            {
                var transition = GetTransition(c, ref pointer);
                if (transition != null)
                {
                    pointer = transition;
                }

                foreach (var result in pointer.Results)
                {
                    yield return result;
                }
            }
        }

        /// <summary>
        /// Walks failure links from <paramref name="pointer"/> until a state with a transition on
        /// <paramref name="c"/> is found, trying <see cref="Root"/> last. On return
        /// <paramref name="pointer"/> is that state (or Root), and the result is the transition
        /// target, or <c>null</c> when Root has none.
        /// </summary>
        private AhoCorasickTreeNode GetTransition(char c, ref AhoCorasickTreeNode pointer)
        {
            AhoCorasickTreeNode transition = null;
            while (transition == null)
            {
                transition = pointer.GetTransition(c);
                if (pointer == Root)
                {
                    break;
                }

                if (transition == null)
                {
                    pointer = pointer.Failure;
                }
            }

            return transition;
        }

        /// <summary>Computes failure links for every node; the root fails to itself.</summary>
        private void SetFailureNodes()
        {
            var nodes = FailToRootNode();
            FailUsingBFS(nodes);
            Root.Failure = Root;
        }

        /// <summary>Inserts <paramref name="pattern"/> into the trie and marks its last node as a keyword.</summary>
        private void AddPatternToTree(string pattern)
        {
            var node = Root;
            foreach (var c in pattern)
            {
                node = node.GetTransition(c) ?? node.AddTransition(c);
            }

            node.AddResult(pattern);
            node.IsWord = true;
        }

        /// <summary>
        /// Points every depth-1 node's failure link at the root and returns the depth-2 nodes,
        /// which are the first BFS level.
        /// </summary>
        private List<AhoCorasickTreeNode> FailToRootNode()
        {
            var nodes = new List<AhoCorasickTreeNode>();
            foreach (var node in Root.Transitions)
            {
                node.Failure = Root;
                nodes.AddRange(node.Transitions);
            }

            return nodes;
        }

        /// <summary>
        /// Level-by-level computation of failure links for nodes at depth 2 and deeper. It also
        /// merges each node's results and match flag with those of its failure target.
        /// </summary>
        private void FailUsingBFS(List<AhoCorasickTreeNode> nodes)
        {
            while (nodes.Count != 0)
            {
                var newNodes = new List<AhoCorasickTreeNode>();
                foreach (var node in nodes)
                {
                    var failure = node.ParentFailure;
                    var value = node.Value;

                    // Root.Failure is still null at this point, so the walk ends with null once
                    // the root has been tried without success.
                    while (failure != null && !failure.ContainsTransition(value))
                    {
                        failure = failure.Failure;
                    }

                    if (failure == null)
                    {
                        node.Failure = Root;
                    }
                    else
                    {
                        node.Failure = failure.GetTransition(value);
                        node.AddResults(node.Failure.Results);

                        // A node matches if its failure target matches. That target is shallower,
                        // so it has already been processed and its flag is final.
                        node.IsWord |= node.Failure.IsWord;
                    }

                    newNodes.AddRange(node.Transitions);
                }

                nodes = newNodes;
            }
        }

        /// <summary>
        /// Serializes the automaton breadth-first into an <see cref="AhoCorasickTreeSlim"/>.
        /// Leaf nodes are not written. Every leaf ends a keyword, and a transition into a keyword
        /// node is written with next offset 0 (the root).
        /// </summary>
        /// <exception cref="NotSupportedException">
        /// A keyword character is above U+00FF, or the image would exceed 16-bit offsets.
        /// </exception>
        public AhoCorasickTreeSlim BuildSlim()
        {
            SetOffsets();
            var data = new List<byte>();
            var queue = new Queue<AhoCorasickTreeNode>();
            queue.Enqueue(Root);
            while (queue.Count > 0)
            {
                var currentNode = queue.Dequeue();
                var entries = currentNode._entries;
                if (entries.Length == 0)
                {
                    continue;
                }

                // The slot count is written as one byte, so a 256-slot node is written as 0.
                data.Add((byte)entries.Length);
                data.AddRange(BitConverter.GetBytes(ToOffset(currentNode.Failure.Offset)));
                foreach (var entry in entries)
                {
                    if (entry.Key == 0)
                    {
                        // Empty slot: zero key, zero offset.
                        data.AddRange(new byte[SizeOfKey]);
                        data.AddRange(new byte[SizeOfNode]);
                        continue;
                    }

                    if (entry.Key > byte.MaxValue)
                    {
                        throw new NotSupportedException("The slim tree only supports keyword characters up to U+00FF.");
                    }

                    queue.Enqueue(entry.Value);
                    data.Add((byte)entry.Key);
                    data.AddRange(BitConverter.GetBytes(entry.Value.IsWord ? (char)0 : ToOffset(entry.Value.Offset)));
                }
            }

            return new AhoCorasickTreeSlim(data.ToArray());
        }

        /// <summary>
        /// Assigns <see cref="AhoCorasickTreeNode.Offset"/> to every non-leaf node. Nodes are
        /// visited in the same breadth-first order in which <see cref="BuildSlim"/> writes them.
        /// </summary>
        private void SetOffsets()
        {
            var roots = 0;
            var elements = 0;
            var queue = new Queue<AhoCorasickTreeNode>();
            queue.Enqueue(Root);
            while (queue.Count > 0)
            {
                var currentNode = queue.Dequeue();
                if (currentNode._entries.Length == 0)
                {
                    continue;
                }

                currentNode.Offset = CalcOffset(roots, elements);
                roots++;
                foreach (var entry in currentNode._entries)
                {
                    if (entry.Key != 0)
                    {
                        queue.Enqueue(entry.Value);
                    }

                    elements++;
                }
            }
        }

        /// <summary>Byte offset after <paramref name="roots"/> node headers and <paramref name="childs"/> slots.</summary>
        private int CalcOffset(int roots, int childs)
        {
            return roots * (SizeOfSize + SizeOfFailure) + childs * (SizeOfKey + SizeOfNode);
        }

        /// <summary>Narrows an offset to the 16-bit on-disk form, refusing to truncate.</summary>
        private static char ToOffset(int offset)
        {
            if (offset > char.MaxValue)
            {
                throw new NotSupportedException("The slim tree is too large: node offsets must fit in 16 bits.");
            }

            return (char)offset;
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Collections.Generic; | |
using System.Linq; | |
namespace Adform.AdServing.AhoCorasickTree.Sandbox.V7g
{
    /// <summary>
    /// Node of the Aho-Corasick trie.
    /// </summary>
    /// <remarks>
    /// Children are kept in a power-of-two sized table of <see cref="Entry"/> slots indexed by
    /// <c>key &amp; (size - 1)</c>. When two keys would share a slot, the table is doubled, so a
    /// lookup is always a single probe. A slot whose key is <c>'\0'</c> counts as empty, which
    /// means a <c>'\0'</c> transition cannot be stored or found.
    /// </remarks>
    internal class AhoCorasickTreeNode
    {
        /// <summary>Character on the edge from the parent (<c>' '</c> for the root).</summary>
        public char Value { get; private set; }

        /// <summary>Failure link; assigned by the owning tree while building.</summary>
        public AhoCorasickTreeNode Failure { get; set; }

        /// <summary>Whether reaching this node means a keyword was matched; assigned by the owning tree.</summary>
        public bool IsWord;

        private readonly List<string> _results;
        private readonly AhoCorasickTreeNode _parent;

        /// <summary>Distinct keywords recorded for this node.</summary>
        public List<string> Results { get { return _results; } }

        /// <summary>The parent's failure link, or <c>null</c> for the root.</summary>
        public AhoCorasickTreeNode ParentFailure { get { return _parent == null ? null : _parent.Failure; } }

        /// <summary>Child nodes in slot order (a new array on every call).</summary>
        public AhoCorasickTreeNode[] Transitions { get { return _entries.Where(x => x.Key != 0).Select(x => x.Value).ToArray(); } }

        /// <summary>Byte offset of this node in the serialized slim form; assigned by the owning tree.</summary>
        public int Offset { get; set; }

        /// <summary>Child slot table; empty until the first child is added.</summary>
        internal Entry[] _entries;

        /// <summary>Current slot count of <see cref="_entries"/>.</summary>
        private int _size;

        /// <summary>Creates a root node.</summary>
        public AhoCorasickTreeNode() : this(null, ' ')
        {
        }

        private AhoCorasickTreeNode(AhoCorasickTreeNode parent, char value)
        {
            Value = value;
            _parent = parent;
            _results = new List<string>();
            _entries = new Entry[0];
        }

        /// <summary>Records <paramref name="result"/> unless it is already present.</summary>
        public void AddResult(string result)
        {
            if (!_results.Contains(result))
            {
                _results.Add(result);
            }
        }

        /// <summary>Records every item of <paramref name="results"/>, skipping duplicates.</summary>
        public void AddResults(IEnumerable<string> results)
        {
            foreach (var result in results)
            {
                AddResult(result);
            }
        }

        /// <summary>
        /// Creates a child for <paramref name="c"/>, stores it (growing the table until its slot is
        /// free) and returns it. An existing child for <paramref name="c"/> is replaced, so callers
        /// check <see cref="GetTransition"/> first.
        /// </summary>
        public AhoCorasickTreeNode AddTransition(char c)
        {
            var node = new AhoCorasickTreeNode(this, c);
            if (_size == 0)
            {
                Resize();
            }

            while (true)
            {
                var ind = c & (_size - 1);
                if (_entries[ind].Key != 0 && _entries[ind].Key != c)
                {
                    // Slot taken by a different key: double the table and retry.
                    Resize();
                    continue;
                }

                _entries[ind].Key = c;
                _entries[ind].Value = node;
                return node;
            }
        }

        /// <summary>Returns the child for <paramref name="c"/>, or <c>null</c> if there is none.</summary>
        public AhoCorasickTreeNode GetTransition(char c)
        {
            if (_size == 0)
            {
                return null;
            }

            var ind = c & (_size - 1);
            var keyThere = _entries[ind].Key;
            if (keyThere != 0 && keyThere == c)
            {
                return _entries[ind].Value;
            }

            return null;
        }

        /// <summary>Whether a child for <paramref name="c"/> exists.</summary>
        public bool ContainsTransition(char c)
        {
            return GetTransition(c) != null;
        }

        /// <summary>Doubles the table (0 slots become 1).</summary>
        private void Resize()
        {
            var newSize = _entries.Length * 2;
            if (newSize == 0)
            {
                newSize = 1;
            }

            Resize(newSize);
        }

        /// <summary>
        /// Rehashes live entries into a table of <paramref name="newSize"/> slots. If two live keys
        /// collide, the size is doubled again.
        /// </summary>
        private void Resize(int newSize)
        {
            var newEntries = new Entry[newSize];
            for (var i = 0; i < _entries.Length; i++)
            {
                var key = _entries[i].Key;

                // Skip empty slots. Rehashing them would target slot 0 and report a false
                // collision with a live key there, doubling the table for no reason.
                if (key == 0)
                {
                    continue;
                }

                var ind = key & (newSize - 1);
                if (newEntries[ind].Key != 0 && newEntries[ind].Key != key)
                {
                    Resize(newSize * 2);
                    return;
                }

                newEntries[ind].Key = key;
                newEntries[ind].Value = _entries[i].Value;
            }

            _entries = newEntries;
            _size = newSize;
        }
    }

    /// <summary>One slot of a node's child table; a <see cref="Key"/> of <c>'\0'</c> marks an empty slot.</summary>
    internal struct Entry
    {
        public char Key;
        public AhoCorasickTreeNode Value;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Runtime.CompilerServices; | |
namespace Adform.AdServing.AhoCorasickTree.Sandbox.V7g
{
    /// <summary>
    /// Read-only, flattened Aho-Corasick automaton stored in a single byte array.
    /// </summary>
    /// <remarks>
    /// Layout read from <see cref="Data"/>. Offsets are byte offsets from the start of the array,
    /// and the root node is at offset 0:
    /// <code>
    /// node := size:byte failure:char slot[size]
    /// slot := key:byte next:char
    /// </code>
    /// A <c>next</c> offset of 0 (the root) marks a transition into a keyword node, i.e. a match.
    /// A slot with key 0 is empty. A size byte of 0 is read as 256 slots. An empty array matches
    /// nothing.
    /// </remarks>
    public class AhoCorasickTreeSlim
    {
        private readonly byte[] _data;

        public AhoCorasickTreeSlim(byte[] data)
        {
            _data = data;
        }

        /// <summary>The serialized automaton.</summary>
        public byte[] Data => _data;

        private const int SizeOfSize = sizeof(byte);
        private const int SizeOfFailure = sizeof(char);
        private const int SizeOfKey = sizeof(byte);
        private const int SizeOfNode = sizeof(char);

        /// <summary>
        /// Slot mask of the node at <paramref name="currentNodePtr"/>. Slot counts are powers of two
        /// up to 256 stored in one byte, so 256 arrives as 0; masking with 0xFF turns that into 255.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static unsafe int GetMask(byte* currentNodePtr)
        {
            return (*currentNodePtr - 1) & 0xFF;
        }

        /// <summary>Key byte of slot <paramref name="ind"/>.</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static unsafe byte GetKey(byte* currentNodePtr, int ind)
        {
            return *(byte*)(currentNodePtr + SizeOfSize + SizeOfFailure + ind * (SizeOfKey + SizeOfNode));
        }

        /// <summary>Node targeted by slot <paramref name="ind"/> (equals <paramref name="b"/> for a match).</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private unsafe byte* GetNext(byte* b, byte* currentNodePtr, int ind)
        {
            return b + *(char*)(currentNodePtr + (SizeOfSize + SizeOfFailure + ind * (SizeOfKey + SizeOfNode) + SizeOfKey));
        }

        /// <summary>Failure target of the node at <paramref name="currentNodePtr"/>.</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private unsafe byte* GetFailure(byte* b, byte* currentNodePtr)
        {
            return b + *(char*)(currentNodePtr + SizeOfSize);
        }

        /// <summary>
        /// Returns <c>true</c> as soon as any keyword encoded in <see cref="Data"/> is found in
        /// <paramref name="text"/>.
        /// </summary>
        /// <param name="text">Text to scan.</param>
        public unsafe bool Contains(string text)
        {
            // No keywords means no root node to read.
            if (_data == null || _data.Length == 0)
            {
                return false;
            }

            fixed (byte* b = _data)
            fixed (char* p = text)
            {
                var currentNodePtr = b;
                for (var i = 0; i < text.Length; i++)
                {
                    var c = p[i];
                    while (true)
                    {
                        var ind = c & GetMask(currentNodePtr);
                        var key = GetKey(currentNodePtr, ind);

                        // Key 0 marks an empty slot and must not match a '\0' in the text.
                        if (key != 0 && key == c)
                        {
                            currentNodePtr = GetNext(b, currentNodePtr, ind);
                            if (currentNodePtr == b)
                            {
                                return true;
                            }

                            break;
                        }

                        if (currentNodePtr == b)
                        {
                            // No transition even from the root: stay there, take the next char.
                            break;
                        }

                        // Follow the failure link and retry c, including at the root itself.
                        currentNodePtr = GetFailure(b, currentNodePtr);
                    }
                }
            }

            return false;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.