Skip to content

Instantly share code, notes, and snippets.

@cogmission
Last active May 12, 2017 14:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cogmission/652cb8d9a84e7686e3329758f00c15a7 to your computer and use it in GitHub Desktop.
Save cogmission/652cb8d9a84e7686e3329758f00c15a7 to your computer and use it in GitHub Desktop.
HTMLocal for phil_d_cat
package test;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.Anomaly;
import org.numenta.nupic.algorithms.Classifier;
import org.numenta.nupic.algorithms.SDRClassifier;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.model.Cell;
import org.numenta.nupic.model.SDR;
import org.numenta.nupic.network.Inference;
import org.numenta.nupic.network.Network;
import org.numenta.nupic.network.PublisherSupplier;
import org.numenta.nupic.network.sensor.ObservableSensor;
import org.numenta.nupic.network.sensor.Publisher;
import org.numenta.nupic.network.sensor.Sensor;
import org.numenta.nupic.network.sensor.SensorParams;
import org.numenta.nupic.util.FastRandom;
import org.numenta.nupic.util.Tuple;
import rx.Observer;
/**
* Created on 5/11/17.
*/
public class HTMLocal {
//private static final Logger LOGGER = LoggerFactory.getLogger(HTMLocal.class);
// https://github.com/numenta/htm.java-examples/blob/master/src/main/java/org/numenta/nupic/examples/napi/hotgym/NetworkDemoHarness.java
public Network createBasicNetwork() {
Parameters p = getParameters().copy();
p = p.union(getDayDemoTestEncoderParams());
p.set(Parameters.KEY.RANDOM, new FastRandom(42));
p.set(Parameters.KEY.INFERRED_FIELDS, getInferredFieldsMap("stamp", SDRClassifier.class));
// The "headers" are the titles of your comma separated fields; (could be "timestamp,consumption,location" for 3 fields)
// The "Data Type" of the field (see FieldMetaTypes) (could be "datetime,float,geo" for 3 field types corresponding to above)
Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
ObservableSensor::create, SensorParams.create(SensorParams.Keys::obs, new Object[] {"name",
PublisherSupplier.builder()
.addHeader("stamp, asn, probeType, responseCode, result")
.addHeader("datetime, int, int, int, int")
.addHeader("T, B") // Special flag. "B" means Blank (see Tests for other examples)
.build() }));
Network network = Network.create("test network", p).add(Network.createRegion("r1")
.add(Network.createLayer("1", p)
// .alterParameter(KEY.AUTO_CLASSIFY, true) // <--- Remove this line if doing anomalies and not predictions
.add(Anomaly.create()) // <--- Remove this line if doing predictions and not anomalies
.add(new TemporalMemory())
.add(new SpatialPooler())
.add(sensor)));
return network;
}
public Parameters getDayDemoTestEncoderParams() {
Map<String, Map<String, Object>> fieldEncodings = getDayDemoFieldEncodingMap();
Parameters p = Parameters.getEncoderDefaultParameters();
p.set(Parameters.KEY.FIELD_ENCODING_MAP, fieldEncodings);
return p;
}
public Map<String, Map<String, Object>> getDayDemoFieldEncodingMap() {
Map<String, Map<String, Object>> fieldEncodings = setupMap(
null,
0, // n
0, // w
0, 0, 0, 0, null, null, null,
"stamp", "datetime", "DateEncoder");
fieldEncodings.get("stamp").put(Parameters.KEY.DATEFIELD_PATTERN.getFieldName(), "YYYY-MM-dd HH:mm:ss");
fieldEncodings.get("stamp").put(Parameters.KEY.DATEFIELD_DOFW.getFieldName(), new Tuple(1, 1.0)); // Day of week
fieldEncodings.get("stamp").put(Parameters.KEY.DATEFIELD_TOFD.getFieldName(), new Tuple(21, 9.5)); // Time of day
fieldEncodings = setupMap(
fieldEncodings,
500, // n
21, // w
0.0, 99999.0, -1, 1, null, null, null,
"asn", "number", "RandomDistributedScalarEncoder");
fieldEncodings = setupMap(
fieldEncodings,
500, // n
21, // w
0.0, 20, -1, 1, null, null, null,
"probeType", "number", "ScalarEncoder");
fieldEncodings = setupMap(
fieldEncodings,
500, // n
21, // w
0, 1, -1, 1, null, null, null,
"responseCode", "int", "ScalarEncoder");
fieldEncodings = setupMap(
fieldEncodings,
500, // n
21, // w
0.0, 99999.0, -1, 1, null, null, null,
"result", "number", "RandomDistributedScalarEncoder");
return fieldEncodings;
}
public Map<String, Map<String, Object>> setupMap(
Map<String, Map<String, Object>> map,
int n, int w, double min, double max, double radius, double resolution, Boolean periodic,
Boolean clip, Boolean forced, String fieldName, String fieldType, String encoderType) {
if(map == null) {
map = new HashMap<>();
}
Map<String, Object> inner;
if((inner = map.get(fieldName)) == null) {
map.put(fieldName, inner = new HashMap<>());
}
inner.put("n", n);
inner.put("w", w);
inner.put("minVal", min);
inner.put("maxVal", max);
if(radius >= 0 ) {
inner.put("radius", radius);
}
if(resolution >= 0 ) {
inner.put("resolution", resolution);
}
if(periodic != null) inner.put("periodic", periodic);
if(clip != null) inner.put("clipInput", clip);
if(forced != null) inner.put("forced", forced);
if(fieldName != null) inner.put("fieldName", fieldName);
if(fieldType != null) inner.put("fieldType", fieldType);
if(encoderType != null) inner.put("encoderType", encoderType);
return map;
}
public Parameters getParameters() {
Parameters parameters = Parameters.getAllDefaultParameters();
parameters.set(Parameters.KEY.INPUT_DIMENSIONS, new int[] { 8 });
parameters.set(Parameters.KEY.COLUMN_DIMENSIONS, new int[] { 20 });
parameters.set(Parameters.KEY.CELLS_PER_COLUMN, 6);
//SpatialPooler specific
parameters.set(Parameters.KEY.POTENTIAL_RADIUS, 12);//3
parameters.set(Parameters.KEY.POTENTIAL_PCT, 0.5);//0.5
parameters.set(Parameters.KEY.GLOBAL_INHIBITION, false);
parameters.set(Parameters.KEY.LOCAL_AREA_DENSITY, -1.0);
parameters.set(Parameters.KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 5.0);
parameters.set(Parameters.KEY.STIMULUS_THRESHOLD, 1.0);
parameters.set(Parameters.KEY.SYN_PERM_INACTIVE_DEC, 0.01);
parameters.set(Parameters.KEY.SYN_PERM_ACTIVE_INC, 0.1);
parameters.set(Parameters.KEY.SYN_PERM_TRIM_THRESHOLD, 0.05);
parameters.set(Parameters.KEY.SYN_PERM_CONNECTED, 0.1);
parameters.set(Parameters.KEY.MIN_PCT_OVERLAP_DUTY_CYCLES, 0.1);
parameters.set(Parameters.KEY.MIN_PCT_ACTIVE_DUTY_CYCLES, 0.1);
parameters.set(Parameters.KEY.DUTY_CYCLE_PERIOD, 10);
parameters.set(Parameters.KEY.MAX_BOOST, 10.0);
parameters.set(Parameters.KEY.SEED, 42);
//Temporal Memory specific
parameters.set(Parameters.KEY.INITIAL_PERMANENCE, 0.2);
parameters.set(Parameters.KEY.CONNECTED_PERMANENCE, 0.8);
parameters.set(Parameters.KEY.MIN_THRESHOLD, 5);
parameters.set(Parameters.KEY.MAX_NEW_SYNAPSE_COUNT, 6);
parameters.set(Parameters.KEY.PERMANENCE_INCREMENT, 0.05);
parameters.set(Parameters.KEY.PERMANENCE_DECREMENT, 0.05);
parameters.set(Parameters.KEY.ACTIVATION_THRESHOLD, 4);
return parameters;
}
public Map<String, Class<? extends Classifier>> getInferredFieldsMap(
String field, Class<? extends Classifier> classifier) {
Map<String, Class<? extends Classifier>> inferredFieldsMap = new HashMap<>();
inferredFieldsMap.put(field, classifier);
return inferredFieldsMap;
}
public static void main(String[] args) {
HTMLocal local = new HTMLocal();
Network network = local.createBasicNetwork();
int cellsPerColumn = (int)network.getParameters().get(KEY.CELLS_PER_COLUMN);
//Subscribe
network.observe().subscribe(new Observer<Inference>() {
@Override public void onCompleted() {}
@Override public void onError(Throwable e) { e.printStackTrace(); }
@Override
public void onNext(Inference inf) {
String layerInput = inf.getLayerInput().toString();
System.out.println("--------------------------------------------------------");
System.out.println("Iteration: " + inf.getRecordNum());
System.out.println("===== Sequence Num: " + (inf.getRecordNum() + 1) + " =====");
System.out.println("layerInput Input = " + layerInput);
System.out.println("SpatialPooler Output = " + Arrays.toString(inf.getFeedForwardActiveColumns()));
int[] predictedColumns = SDR.cellsAsColumnIndices(inf.getPredictiveCells(), cellsPerColumn); //Get the predicted column indexes
System.out.println("TemporalMemory Input = " + Arrays.toString(inf.getFeedForwardSparseActives()));
System.out.println("TemporalMemory Prediction = " + Arrays.toString(predictedColumns));
Set<Cell> actives = inf.getActiveCells();
int[] actCellIndices = SDR.asCellIndices(actives);
System.out.println("TemporalMemory Active Cells = " + Arrays.toString(actCellIndices));
Set<Cell> pred = inf.getPredictiveCells();
int[] predCellIndices = SDR.asCellIndices(pred);
System.out.println("TemporalMemory Predictive Cells = " + Arrays.toString(predCellIndices));
//Anomaly
double score = inf.getAnomalyScore();
System.out.println("Anomaly Score = " + score);
}
});
// Get the Publisher
Publisher publisher = network.getPublisher();
// Start the Network
network.start();
// Feed in the data
//.addHeader("stamp, asn, probeType, responseCode, result")
//.addHeader("datetime, int, int, bool, int")
publisher.onNext("2017-05-11 06:46:21,4,5.0,1,6");
publisher.onComplete();
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment