Skip to content

Instantly share code, notes, and snippets.

@ethrbunny
Last active May 11, 2017 23:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ethrbunny/4d5a6c034febbc9b3f1a27d74c954690 to your computer and use it in GitHub Desktop.
Save ethrbunny/4d5a6c034febbc9b3f1a27d74c954690 to your computer and use it in GitHub Desktop.
import org.numenta.nupic.Parameters;
import org.numenta.nupic.algorithms.*;
import org.numenta.nupic.network.Network;
import org.numenta.nupic.network.PublisherSupplier;
import org.numenta.nupic.network.sensor.ObservableSensor;
import org.numenta.nupic.network.sensor.Sensor;
import org.numenta.nupic.network.sensor.SensorParams;
import org.numenta.nupic.util.FastRandom;
import org.numenta.nupic.util.Tuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
/**
 * Assembles a single-region, single-layer HTM.java {@link Network} whose layer holds a sensor,
 * a {@link SpatialPooler}, a {@link TemporalMemory} and an {@link Anomaly}, fed by rows with the
 * columns {@code stamp, asn, probeType, responseCode, result}.
 *
 * <p>Adapted from the htm.java-examples harness:
 * https://github.com/numenta/htm.java-examples/blob/master/src/main/java/org/numenta/nupic/examples/napi/hotgym/NetworkDemoHarness.java
 *
 * <p>Created on 5/11/17.
 */
public class HTMLocal {
    private static final Logger LOGGER = LoggerFactory.getLogger(HTMLocal.class);

    // Field names are declared once and used for BOTH the sensor header and the field-encoding
    // map, so the two cannot drift apart (previously the header said "result" while the encoder
    // entry was keyed "results").
    private static final String FIELD_STAMP = "stamp";
    private static final String FIELD_ASN = "asn";
    private static final String FIELD_PROBE_TYPE = "probeType";
    private static final String FIELD_RESPONSE_CODE = "responseCode";
    private static final String FIELD_RESULT = "result";

    /**
     * Creates the network: parameters from {@link #getParameters()} merged with the encoder
     * parameters from {@link #getDayDemoTestEncoderParams()}, a {@link FastRandom} seeded with 42,
     * and one region "r1" containing one layer "1".
     *
     * @return the assembled (not yet started) network
     */
    public Network createBasicNetwork() {
        Parameters p = getParameters().copy();
        p = p.union(getDayDemoTestEncoderParams());
        p.set(Parameters.KEY.RANDOM, new FastRandom(42));
        // To classify/predict a field, additionally set:
        // p.set(Parameters.KEY.INFERRED_FIELDS, getInferredFieldsMap("stamp", SDRClassifier.class));

        // Header line 1: the titles of the comma separated fields.
        // Evaluates to exactly "stamp, asn, probeType, responseCode, result".
        String fieldNames = String.join(", ",
            FIELD_STAMP, FIELD_ASN, FIELD_PROBE_TYPE, FIELD_RESPONSE_CODE, FIELD_RESULT);

        // Header line 2: the "Data Type" of each field (see FieldMetaTypes).
        // NOTE(review): responseCode is declared "bool" here, but its encoder is a numeric
        // ScalarEncoder over 0..5 -- confirm which of the two is intended.
        Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
            ObservableSensor::create,
            SensorParams.create(SensorParams.Keys::obs, new Object[] {
                "name",
                PublisherSupplier.builder()
                    .addHeader(fieldNames)
                    .addHeader("datetime, int, int, bool, int")
                    .addHeader("T, B") // Special flag. "B" means Blank (see Tests for other examples)
                    .build()
            }));

        Network network = Network.create("test network", p)
            .add(Network.createRegion("r1")
                .add(Network.createLayer("1", p)
                    // .alterParameter(KEY.AUTO_CLASSIFY, true) // <--- Remove this line if doing anomalies and not predictions
                    .add(Anomaly.create()) // <--- Remove this line if doing predictions and not anomalies
                    .add(new TemporalMemory())
                    .add(new SpatialPooler())
                    .add(sensor)));

        return network;
    }

    /**
     * Returns the default encoder parameters with {@code FIELD_ENCODING_MAP} set to the map
     * produced by {@link #getDayDemoFieldEncodingMap()}.
     *
     * @return encoder parameters carrying the field-encoding map
     */
    public Parameters getDayDemoTestEncoderParams() {
        Map<String, Map<String, Object>> fieldEncodings = getDayDemoFieldEncodingMap();

        Parameters p = Parameters.getEncoderDefaultParameters();
        p.set(Parameters.KEY.FIELD_ENCODING_MAP, fieldEncodings);
        return p;
    }

    /**
     * Builds the per-field encoder settings, one entry per sensor header column:
     * a DateEncoder for {@code stamp}, RandomDistributedScalarEncoders for {@code asn} and
     * {@code result}, and ScalarEncoders for {@code probeType} and {@code responseCode}.
     *
     * @return map from field name to that field's encoder settings
     */
    public Map<String, Map<String, Object>> getDayDemoFieldEncodingMap() {
        Map<String, Map<String, Object>> fieldEncodings = setupMap(
            null,
            0, // n
            0, // w
            0, 0, 0, 0, null, null, null,
            FIELD_STAMP, "datetime", "DateEncoder");

        Map<String, Object> stampSettings = fieldEncodings.get(FIELD_STAMP);
        stampSettings.put(Parameters.KEY.DATEFIELD_PATTERN.getFieldName(), "YYYY-MM-dd HH:mm:ss");
        stampSettings.put(Parameters.KEY.DATEFIELD_DOFW.getFieldName(), new Tuple(1, 1.0)); // Day of week
        stampSettings.put(Parameters.KEY.DATEFIELD_TOFD.getFieldName(), new Tuple(21, 9.5)); // Time of day

        fieldEncodings = setupMap(
            fieldEncodings,
            500, // n
            21, // w
            0.0, 99999.0, -1, 1, null, null, null,
            FIELD_ASN, "number", "RandomDistributedScalarEncoder");

        fieldEncodings = setupMap(
            fieldEncodings,
            500, // n
            21, // w
            0.0, 20, -1, 1, null, null, null,
            FIELD_PROBE_TYPE, "number", "ScalarEncoder");

        fieldEncodings = setupMap(
            fieldEncodings,
            500, // n
            21, // w
            0.0, 5.0, -1, 1, null, null, null,
            FIELD_RESPONSE_CODE, "number", "ScalarEncoder");

        // Keyed "result" (not "results") so it agrees with the column name in the sensor header.
        fieldEncodings = setupMap(
            fieldEncodings,
            500, // n
            21, // w
            0.0, 99999.0, -1, 1, null, null, null,
            FIELD_RESULT, "number", "RandomDistributedScalarEncoder");

        return fieldEncodings;
    }

    /**
     * Adds (or updates) the encoder settings for one field in {@code map}.
     *
     * <p>{@code n}, {@code w}, {@code minVal} and {@code maxVal} are always written.
     * {@code radius} and {@code resolution} are written only when non-negative, so a negative
     * value leaves them out. The remaining settings are written only when non-null.
     *
     * @param map         the map to add to; a new {@link HashMap} is created when {@code null}
     * @param n           stored under key {@code "n"}
     * @param w           stored under key {@code "w"}
     * @param min         stored under key {@code "minVal"}
     * @param max         stored under key {@code "maxVal"}
     * @param radius      stored under key {@code "radius"} when {@code >= 0}
     * @param resolution  stored under key {@code "resolution"} when {@code >= 0}
     * @param periodic    stored under key {@code "periodic"} when non-null
     * @param clip        stored under key {@code "clipInput"} when non-null
     * @param forced      stored under key {@code "forced"} when non-null
     * @param fieldName   key of the entry in {@code map}; also stored under {@code "fieldName"}
     *                    when non-null
     * @param fieldType   stored under key {@code "fieldType"} when non-null
     * @param encoderType stored under key {@code "encoderType"} when non-null
     * @return {@code map}, or the newly created map when {@code map} was {@code null}
     */
    public Map<String, Map<String, Object>> setupMap(
            Map<String, Map<String, Object>> map,
            int n, int w, double min, double max, double radius, double resolution, Boolean periodic,
            Boolean clip, Boolean forced, String fieldName, String fieldType, String encoderType) {

        if (map == null) {
            map = new HashMap<>();
        }

        // Reuse the field's existing settings when present so repeated calls update in place.
        Map<String, Object> inner = map.computeIfAbsent(fieldName, key -> new HashMap<>());

        inner.put("n", n);
        inner.put("w", w);
        inner.put("minVal", min);
        inner.put("maxVal", max);

        if (radius >= 0) {
            inner.put("radius", radius);
        }
        if (resolution >= 0) {
            inner.put("resolution", resolution);
        }
        if (periodic != null) {
            inner.put("periodic", periodic);
        }
        if (clip != null) {
            inner.put("clipInput", clip);
        }
        if (forced != null) {
            inner.put("forced", forced);
        }
        if (fieldName != null) {
            inner.put("fieldName", fieldName);
        }
        if (fieldType != null) {
            inner.put("fieldType", fieldType);
        }
        if (encoderType != null) {
            inner.put("encoderType", encoderType);
        }

        return map;
    }

    /**
     * Returns all default parameters with the SpatialPooler and TemporalMemory overrides below
     * applied.
     *
     * @return a fresh parameter set for the network
     */
    public Parameters getParameters() {
        Parameters parameters = Parameters.getAllDefaultParameters();
        parameters.set(Parameters.KEY.INPUT_DIMENSIONS, new int[] { 8 });
        parameters.set(Parameters.KEY.COLUMN_DIMENSIONS, new int[] { 20 });
        parameters.set(Parameters.KEY.CELLS_PER_COLUMN, 6);

        //SpatialPooler specific
        parameters.set(Parameters.KEY.POTENTIAL_RADIUS, 12);//3
        parameters.set(Parameters.KEY.POTENTIAL_PCT, 0.5);//0.5
        parameters.set(Parameters.KEY.GLOBAL_INHIBITION, false);
        parameters.set(Parameters.KEY.LOCAL_AREA_DENSITY, -1.0);
        parameters.set(Parameters.KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 5.0);
        parameters.set(Parameters.KEY.STIMULUS_THRESHOLD, 1.0);
        parameters.set(Parameters.KEY.SYN_PERM_INACTIVE_DEC, 0.01);
        parameters.set(Parameters.KEY.SYN_PERM_ACTIVE_INC, 0.1);
        parameters.set(Parameters.KEY.SYN_PERM_TRIM_THRESHOLD, 0.05);
        parameters.set(Parameters.KEY.SYN_PERM_CONNECTED, 0.1);
        parameters.set(Parameters.KEY.MIN_PCT_OVERLAP_DUTY_CYCLES, 0.1);
        parameters.set(Parameters.KEY.MIN_PCT_ACTIVE_DUTY_CYCLES, 0.1);
        parameters.set(Parameters.KEY.DUTY_CYCLE_PERIOD, 10);
        parameters.set(Parameters.KEY.MAX_BOOST, 10.0);
        parameters.set(Parameters.KEY.SEED, 42);

        //Temporal Memory specific
        parameters.set(Parameters.KEY.INITIAL_PERMANENCE, 0.2);
        parameters.set(Parameters.KEY.CONNECTED_PERMANENCE, 0.8);
        parameters.set(Parameters.KEY.MIN_THRESHOLD, 5);
        parameters.set(Parameters.KEY.MAX_NEW_SYNAPSE_COUNT, 6);
        parameters.set(Parameters.KEY.PERMANENCE_INCREMENT, 0.05);
        parameters.set(Parameters.KEY.PERMANENCE_DECREMENT, 0.05);
        parameters.set(Parameters.KEY.ACTIVATION_THRESHOLD, 4);

        return parameters;
    }

    /**
     * Builds a one-entry map associating {@code field} with {@code classifier}, in the shape
     * used for {@code Parameters.KEY.INFERRED_FIELDS}.
     *
     * @param field      name of the field to infer
     * @param classifier classifier class to use for that field
     * @return a new mutable map containing only that association
     */
    public Map<String, Class<? extends Classifier>> getInferredFieldsMap(
            String field, Class<? extends Classifier> classifier) {
        Map<String, Class<? extends Classifier>> inferredFieldsMap = new HashMap<>();
        inferredFieldsMap.put(field, classifier);
        return inferredFieldsMap;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment