Skip to content

Instantly share code, notes, and snippets.

@ClickerMonkey
Created October 29, 2014 15:19
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ClickerMonkey/289af1bab22974bb9c42 to your computer and use it in GitHub Desktop.
Save ClickerMonkey/289af1bab22974bb9c42 to your computer and use it in GitHub Desktop.
Markov chain data structure.
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Random;
/**
 * A Markov model over states of type {@code T}, trained from example sequences and able to
 * generate new ones.
 * <p>
 * Training ({@link #build(int, List)}) stores every window of up to {@code depth} consecutive
 * states as a path in a trie of {@link MarkovChain} nodes under {@link #root}. A node reached by
 * walking {@code s1, s2, ...} from the root therefore counts which states followed that context
 * and how often. The first state of each training sequence is additionally counted under
 * {@link #starters}.
 * <p>
 * This class performs no synchronization of its own.
 *
 * @param <T> the state type; {@link MarkovChain} stores states as {@code HashMap} keys
 */
public class Markov<T>
{
    /** Trie of all training windows; its children are every state seen at any position. */
    protected final MarkovChain<T> root;

    /** One-level trie whose children are the first states of the training sequences. */
    protected final MarkovChain<T> starters;

    /** Creates an empty model; call {@link #build} before generating. */
    public Markov()
    {
        root = new MarkovChain<T>( null, null );
        starters = new MarkovChain<T>( null, null );
    }

    /**
     * Adds one training sequence to the model.
     * <p>
     * For every position {@code i}, the window {@code states[i .. i + depth - 1]} (shorter near
     * the end of the sequence) is added as a path under {@link #root}. Every window that reaches
     * the last state of the sequence marks that last node as an end, so generation can learn
     * where sequences stop.
     *
     * @param depth  maximum window length; a window of length {@code d} records which state
     *               followed a context of {@code d - 1} states
     * @param states the training sequence; an empty list is ignored
     */
    public void build( int depth, List<T> states )
    {
        final int stateCount = states.size();
        if (stateCount == 0)
        {
            return;
        }
        starters.addNext( states.get( 0 ) );
        for (int i = 0; i < stateCount; i++)
        {
            final int chainDepth = Math.min( depth, stateCount - i );
            MarkovChain<T> chain = root;
            for (int k = 0; k < chainDepth; k++)
            {
                chain = chain.addNext( states.get( i + k ) );
            }
            // Only the node for the sequence's final state is an end; intermediate nodes of the
            // window were followed by further states and must not be flagged as terminal.
            if (i + chainDepth == stateCount)
            {
                chain.addEnd();
            }
        }
    }

    /**
     * Varargs form of {@link #build(int, List)}.
     *
     * @param depth  maximum window length
     * @param states the training sequence; an empty array is ignored
     */
    public void build( int depth, T... states )
    {
        build( depth, Arrays.asList( states ) );
    }

    /**
     * Intended to return the probability of {@code states} under this model.
     * <p>
     * NOTE(review): not implemented — always returns {@code 0.0}.
     *
     * @param states the sequence to score (currently ignored)
     * @return always {@code 0.0}
     */
    public double probability( T ... states )
    {
        return 0.0;
    }

    /**
     * Fills {@code out} with a generated sequence that begins with {@code initialState},
     * possibly stopping early at a learned sequence end.
     * <p>
     * Generation follows the trie node matching the most recent output. When that node has no
     * recorded successor, the oldest state is dropped from the context and the shorter suffix is
     * looked up again. When no suffix is left, a state drawn from {@link #getRandom(Random)} is
     * used as a fresh context; that state itself is not written to {@code out}. Once more than
     * {@code min} states have been written, each step may stop with the end probability recorded
     * by {@link MarkovChain#isEnd(Random)}.
     * <p>
     * NOTE(review): if no root-level state has a successor (e.g. every {@code build} used
     * {@code depth == 1}), the context search below never finds one and this method does not
     * terminate.
     *
     * @param random       source of randomness
     * @param initialState state written to {@code out[0]}; if the model has never seen it,
     *                     generation continues from a random context
     * @param min          stopping early is only considered after more than {@code min} states
     *                     have been written
     * @param out          destination array; at most {@code out.length} states are written
     * @return the number of states written to {@code out}
     * @throws IllegalStateException if no training data has been added
     */
    public int generateRandom( Random random, T initialState, int min, T[] out )
    {
        requireTrained();
        // Index of the oldest state in the current context; after a random restart this slot
        // stands in for the (unwritten) restart state.
        int contextStart = 0;
        int written = 0;
        MarkovChain<T> chain = child( root, initialState );
        T state = initialState;
        while (written < out.length)
        {
            out[written] = state;
            // Shorten the context from the left until it has a recorded successor.
            while (chain == null || !chain.hasNext())
            {
                if (contextStart >= written)
                {
                    // No usable suffix left: continue from a random state's node.
                    chain = child( root, getRandom( random ) );
                    contextStart = written;
                }
                else
                {
                    contextStart++;
                    chain = descend( out, contextStart, written );
                }
            }
            written++;
            if (written > min && chain.isEnd() && chain.isEnd( random ))
            {
                break;
            }
            chain = chain.getRandom( random );
            state = chain.getState();
        }
        return written;
    }

    /**
     * Fills all of {@code out} with a generated sequence whose first state is drawn from
     * {@link #getRandom(Random)}.
     *
     * @param random source of randomness
     * @param out    destination array, completely overwritten
     * @return {@code out}
     * @throws IllegalStateException if no training data has been added
     */
    public T[] generate( Random random, T[] out )
    {
        return generate( random, getRandom( random ), out );
    }

    /**
     * Fills all of {@code out} with a generated sequence that begins with {@code initialState};
     * learned sequence ends are ignored.
     *
     * @param random       source of randomness
     * @param initialState state written to {@code out[0]}
     * @param out          destination array, completely overwritten
     * @return {@code out}
     * @throws IllegalStateException if no training data has been added
     */
    public T[] generate( Random random, T initialState, T[] out )
    {
        // With min == out.length the early-stop test (written > min) can never pass in the loop.
        generateRandom( random, initialState, out.length, out );
        return out;
    }

    /**
     * Appends a generated sequence to {@code destination}, starting from a random training
     * start state.
     * <p>
     * The first state comes from {@link #getRandomStart(Random)}. Each following state is drawn
     * from the trie node for the states seen since the last reset, and the context resets to the
     * root whenever that node has no recorded successor.
     * <p>
     * NOTE(review): at most {@code max + 1} states are added, and stopping at a recorded end is
     * only considered once {@code min + 2} states have been added; confirm these bounds are
     * what callers expect.
     *
     * @param random      source of randomness
     * @param min         see note above
     * @param max         see note above
     * @param destination collection that receives the generated states, in order
     * @return {@code destination}
     * @throws IllegalStateException if no training data has been added
     */
    public <C extends Collection<T>> C generateRandomSize( Random random, int min, int max, C destination )
    {
        requireTrained();
        T state = getRandomStart( random );
        destination.add( state );
        // Continue from the start state's node in the main trie: the starters trie is only one
        // level deep, so its own nodes never have successors.
        MarkovChain<T> start = child( root, state );
        if (start == null || !start.hasNext())
        {
            return destination;
        }
        state = getRandom( start, random );
        for (int i = 0; i < max; i++)
        {
            destination.add( state );
            MarkovChain<T> next = child( start, state );
            start = (next == null || !next.hasNext()) ? root : next;
            if (i >= min && next != null && next.isEnd() && next.isEnd( random ))
            {
                break;
            }
            state = getRandom( start, random );
        }
        return destination;
    }

    /**
     * Draws a state from all root-level states, weighted by how often each was seen.
     *
     * @param random source of randomness
     * @return a state, or {@code null} if the model is empty
     */
    public T getRandom( Random random )
    {
        return getRandom( root, random );
    }

    /**
     * Draws a next state given a context.
     * <p>
     * Walks from the root along {@code previousStates[0 .. previousStateCount - 1]}, stopping
     * before any state that was never seen in that context or whose node has no successors,
     * then draws from the deepest node reached. The walk starts at {@code previousStates[0]},
     * so presumably callers pass the most recent states as the context window.
     *
     * @param random             source of randomness
     * @param previousStates     context states, oldest first
     * @param previousStateCount number of leading entries of {@code previousStates} to use
     * @return a state, or {@code null} if the model is empty
     */
    public T getRandom( Random random, T[] previousStates, int previousStateCount )
    {
        MarkovChain<T> context = root;
        for (int i = 0; i < previousStateCount; i++)
        {
            MarkovChain<T> candidate = child( context, previousStates[i] );
            if (candidate == null || !candidate.hasNext())
            {
                break;
            }
            context = candidate;
        }
        return getRandom( context, random );
    }

    /**
     * Draws a successor state of {@code chain}, weighted by observation counts.
     *
     * @return the state, or {@code null} if {@code chain} has no successors
     */
    protected T getRandom( MarkovChain<T> chain, Random random )
    {
        MarkovChain<T> r = chain.getRandom( random );
        return (r != null ? r.getState() : null);
    }

    /**
     * Draws a state from the first states of the training sequences.
     *
     * @return the state, or {@code null} if nothing has been built
     */
    protected T getRandomStart( Random random )
    {
        return getRandom( starters, random );
    }

    /** Returns the child of {@code node} for {@code state}, or {@code null} if there is none. */
    private MarkovChain<T> child( MarkovChain<T> node, T state )
    {
        return node.hasNext() ? node.getNext( state ) : null;
    }

    /** Walks the trie from the root along {@code path[from .. to]}; {@code null} if unseen. */
    private MarkovChain<T> descend( T[] path, int from, int to )
    {
        MarkovChain<T> node = root;
        for (int k = from; k <= to && node != null; k++)
        {
            node = child( node, path[k] );
        }
        return node;
    }

    /** Fails fast instead of looping or throwing NPE when no data has been built. */
    private void requireTrained()
    {
        if (!root.hasNext())
        {
            throw new IllegalStateException( "Markov model has no data; call build(...) first" );
        }
    }
}
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
/**
 * One node of a trie of observed state sequences.
 * <p>
 * A node stores the state it represents and how many times it was reached
 * ({@link #getOccurrences()}). It also stores its children, keyed by the state that followed,
 * and how many observed sequences terminated here ({@link #addEnd()}).
 *
 * @param <T> the state type; states are used as {@link HashMap} keys, so {@code equals} and
 *            {@code hashCode} must be consistent
 */
public class MarkovChain<T>
{
    /** The state this node represents. */
    protected final T state;

    /** The node passed to the constructor; {@link #addNext} passes {@code this}. */
    protected final MarkovChain<T> parent;

    /** Number of {@link #addOccurrence()} calls on this node. */
    protected int occurrences;

    /** Children keyed by the following state; created lazily by {@link #addNext}. */
    protected Map<T, MarkovChain<T>> next;

    /** Number of {@link #addNext} calls on this node (the weight total for {@link #getRandom}). */
    protected int nextTotalOccurrences;

    /** Number of {@link #addEnd()} calls on this node. */
    protected int ends;

    public MarkovChain( T state, MarkovChain<T> parent )
    {
        this.state = state;
        this.parent = parent;
    }

    /** @return the state this node represents */
    public T getState()
    {
        return state;
    }

    /** @return how many times this node has been recorded */
    public int getOccurrences()
    {
        return occurrences;
    }

    /** Increments this node's occurrence count. */
    public void addOccurrence()
    {
        occurrences++;
    }

    /**
     * Records that {@code nextValue} followed this node, creating the child if needed.
     *
     * @param nextValue the following state
     * @return the child node for {@code nextValue}
     */
    public MarkovChain<T> addNext( T nextValue )
    {
        if (next == null)
        {
            next = new HashMap<T, MarkovChain<T>>();
        }
        MarkovChain<T> chain = next.get( nextValue );
        if (chain == null)
        {
            chain = new MarkovChain<T>( nextValue, this );
            next.put( nextValue, chain );
        }
        chain.addOccurrence();
        nextTotalOccurrences++;
        return chain;
    }

    /**
     * Picks a child at random, weighted by each child's occurrence count.
     *
     * @param random source of randomness
     * @return the chosen child, or {@code null} if this node has no children
     */
    public MarkovChain<T> getRandom( Random random )
    {
        if (next == null || nextTotalOccurrences <= 0)
        {
            return null;
        }
        int remaining = random.nextInt( nextTotalOccurrences );
        for (MarkovChain<T> candidate : next.values())
        {
            remaining -= candidate.occurrences;
            // remaining is in [0, total); a child owns the draws that push it below zero, so each
            // child is chosen exactly `occurrences` times out of `total`.
            if (remaining < 0)
            {
                return candidate;
            }
        }
        return null;
    }

    /**
     * @param value a possible following state
     * @return the fraction of recorded successors equal to {@code value}; {@code 0.0} if none
     */
    public double getProbability( T value )
    {
        MarkovChain<T> n = (next == null ? null : next.get( value ));
        return (n == null ? 0.0 : (double)n.occurrences / (double)nextTotalOccurrences );
    }

    /** Records that an observed sequence ended at this node. */
    public void addEnd()
    {
        ends++;
    }

    /**
     * Randomly decides whether to stop here, with probability
     * {@code ends / (ends + nextTotalOccurrences)}.
     *
     * @param random source of randomness
     * @return {@code true} to stop; always {@code false} for a node with no ends or successors
     */
    public boolean isEnd( Random random )
    {
        final int total = ends + nextTotalOccurrences;
        if (total <= 0)
        {
            // Random.nextInt requires a positive bound.
            return false;
        }
        return random.nextInt( total ) < ends;
    }

    /** @return whether any observed sequence ended at this node */
    public boolean isEnd()
    {
        return (ends > 0);
    }

    /**
     * @param value a following state
     * @return the child for {@code value}, or {@code null} if none was recorded
     */
    public MarkovChain<T> getNext( T value )
    {
        return (next == null ? null : next.get( value ));
    }

    /** @return whether this node has at least one child */
    public boolean hasNext()
    {
        return (next != null && !next.isEmpty());
    }

    /** @return whether a child for {@code value} was recorded */
    public boolean hasNext( T value )
    {
        return (next != null && next.containsKey( value ));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment