Last active
May 12, 2017 14:12
-
-
Save cogmission/652cb8d9a84e7686e3329758f00c15a7 to your computer and use it in GitHub Desktop.
HTMLocal for phil_d_cat
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package test; | |
import java.util.Arrays; | |
import java.util.HashMap; | |
import java.util.Map; | |
import java.util.Set; | |
import org.numenta.nupic.Parameters; | |
import org.numenta.nupic.Parameters.KEY; | |
import org.numenta.nupic.algorithms.Anomaly; | |
import org.numenta.nupic.algorithms.Classifier; | |
import org.numenta.nupic.algorithms.SDRClassifier; | |
import org.numenta.nupic.algorithms.SpatialPooler; | |
import org.numenta.nupic.algorithms.TemporalMemory; | |
import org.numenta.nupic.model.Cell; | |
import org.numenta.nupic.model.SDR; | |
import org.numenta.nupic.network.Inference; | |
import org.numenta.nupic.network.Network; | |
import org.numenta.nupic.network.PublisherSupplier; | |
import org.numenta.nupic.network.sensor.ObservableSensor; | |
import org.numenta.nupic.network.sensor.Publisher; | |
import org.numenta.nupic.network.sensor.Sensor; | |
import org.numenta.nupic.network.sensor.SensorParams; | |
import org.numenta.nupic.util.FastRandom; | |
import org.numenta.nupic.util.Tuple; | |
import rx.Observer; | |
/** | |
* Created on 5/11/17. | |
*/ | |
public class HTMLocal { | |
//private static final Logger LOGGER = LoggerFactory.getLogger(HTMLocal.class); | |
// https://github.com/numenta/htm.java-examples/blob/master/src/main/java/org/numenta/nupic/examples/napi/hotgym/NetworkDemoHarness.java | |
public Network createBasicNetwork() { | |
Parameters p = getParameters().copy(); | |
p = p.union(getDayDemoTestEncoderParams()); | |
p.set(Parameters.KEY.RANDOM, new FastRandom(42)); | |
p.set(Parameters.KEY.INFERRED_FIELDS, getInferredFieldsMap("stamp", SDRClassifier.class)); | |
// The "headers" are the titles of your comma separated fields; (could be "timestamp,consumption,location" for 3 fields) | |
// The "Data Type" of the field (see FieldMetaTypes) (could be "datetime,float,geo" for 3 field types corresponding to above) | |
Sensor<ObservableSensor<String[]>> sensor = Sensor.create( | |
ObservableSensor::create, SensorParams.create(SensorParams.Keys::obs, new Object[] {"name", | |
PublisherSupplier.builder() | |
.addHeader("stamp, asn, probeType, responseCode, result") | |
.addHeader("datetime, int, int, int, int") | |
.addHeader("T, B") // Special flag. "B" means Blank (see Tests for other examples) | |
.build() })); | |
Network network = Network.create("test network", p).add(Network.createRegion("r1") | |
.add(Network.createLayer("1", p) | |
// .alterParameter(KEY.AUTO_CLASSIFY, true) // <--- Remove this line if doing anomalies and not predictions | |
.add(Anomaly.create()) // <--- Remove this line if doing predictions and not anomalies | |
.add(new TemporalMemory()) | |
.add(new SpatialPooler()) | |
.add(sensor))); | |
return network; | |
} | |
public Parameters getDayDemoTestEncoderParams() { | |
Map<String, Map<String, Object>> fieldEncodings = getDayDemoFieldEncodingMap(); | |
Parameters p = Parameters.getEncoderDefaultParameters(); | |
p.set(Parameters.KEY.FIELD_ENCODING_MAP, fieldEncodings); | |
return p; | |
} | |
public Map<String, Map<String, Object>> getDayDemoFieldEncodingMap() { | |
Map<String, Map<String, Object>> fieldEncodings = setupMap( | |
null, | |
0, // n | |
0, // w | |
0, 0, 0, 0, null, null, null, | |
"stamp", "datetime", "DateEncoder"); | |
fieldEncodings.get("stamp").put(Parameters.KEY.DATEFIELD_PATTERN.getFieldName(), "YYYY-MM-dd HH:mm:ss"); | |
fieldEncodings.get("stamp").put(Parameters.KEY.DATEFIELD_DOFW.getFieldName(), new Tuple(1, 1.0)); // Day of week | |
fieldEncodings.get("stamp").put(Parameters.KEY.DATEFIELD_TOFD.getFieldName(), new Tuple(21, 9.5)); // Time of day | |
fieldEncodings = setupMap( | |
fieldEncodings, | |
500, // n | |
21, // w | |
0.0, 99999.0, -1, 1, null, null, null, | |
"asn", "number", "RandomDistributedScalarEncoder"); | |
fieldEncodings = setupMap( | |
fieldEncodings, | |
500, // n | |
21, // w | |
0.0, 20, -1, 1, null, null, null, | |
"probeType", "number", "ScalarEncoder"); | |
fieldEncodings = setupMap( | |
fieldEncodings, | |
500, // n | |
21, // w | |
0, 1, -1, 1, null, null, null, | |
"responseCode", "int", "ScalarEncoder"); | |
fieldEncodings = setupMap( | |
fieldEncodings, | |
500, // n | |
21, // w | |
0.0, 99999.0, -1, 1, null, null, null, | |
"result", "number", "RandomDistributedScalarEncoder"); | |
return fieldEncodings; | |
} | |
public Map<String, Map<String, Object>> setupMap( | |
Map<String, Map<String, Object>> map, | |
int n, int w, double min, double max, double radius, double resolution, Boolean periodic, | |
Boolean clip, Boolean forced, String fieldName, String fieldType, String encoderType) { | |
if(map == null) { | |
map = new HashMap<>(); | |
} | |
Map<String, Object> inner; | |
if((inner = map.get(fieldName)) == null) { | |
map.put(fieldName, inner = new HashMap<>()); | |
} | |
inner.put("n", n); | |
inner.put("w", w); | |
inner.put("minVal", min); | |
inner.put("maxVal", max); | |
if(radius >= 0 ) { | |
inner.put("radius", radius); | |
} | |
if(resolution >= 0 ) { | |
inner.put("resolution", resolution); | |
} | |
if(periodic != null) inner.put("periodic", periodic); | |
if(clip != null) inner.put("clipInput", clip); | |
if(forced != null) inner.put("forced", forced); | |
if(fieldName != null) inner.put("fieldName", fieldName); | |
if(fieldType != null) inner.put("fieldType", fieldType); | |
if(encoderType != null) inner.put("encoderType", encoderType); | |
return map; | |
} | |
public Parameters getParameters() { | |
Parameters parameters = Parameters.getAllDefaultParameters(); | |
parameters.set(Parameters.KEY.INPUT_DIMENSIONS, new int[] { 8 }); | |
parameters.set(Parameters.KEY.COLUMN_DIMENSIONS, new int[] { 20 }); | |
parameters.set(Parameters.KEY.CELLS_PER_COLUMN, 6); | |
//SpatialPooler specific | |
parameters.set(Parameters.KEY.POTENTIAL_RADIUS, 12);//3 | |
parameters.set(Parameters.KEY.POTENTIAL_PCT, 0.5);//0.5 | |
parameters.set(Parameters.KEY.GLOBAL_INHIBITION, false); | |
parameters.set(Parameters.KEY.LOCAL_AREA_DENSITY, -1.0); | |
parameters.set(Parameters.KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 5.0); | |
parameters.set(Parameters.KEY.STIMULUS_THRESHOLD, 1.0); | |
parameters.set(Parameters.KEY.SYN_PERM_INACTIVE_DEC, 0.01); | |
parameters.set(Parameters.KEY.SYN_PERM_ACTIVE_INC, 0.1); | |
parameters.set(Parameters.KEY.SYN_PERM_TRIM_THRESHOLD, 0.05); | |
parameters.set(Parameters.KEY.SYN_PERM_CONNECTED, 0.1); | |
parameters.set(Parameters.KEY.MIN_PCT_OVERLAP_DUTY_CYCLES, 0.1); | |
parameters.set(Parameters.KEY.MIN_PCT_ACTIVE_DUTY_CYCLES, 0.1); | |
parameters.set(Parameters.KEY.DUTY_CYCLE_PERIOD, 10); | |
parameters.set(Parameters.KEY.MAX_BOOST, 10.0); | |
parameters.set(Parameters.KEY.SEED, 42); | |
//Temporal Memory specific | |
parameters.set(Parameters.KEY.INITIAL_PERMANENCE, 0.2); | |
parameters.set(Parameters.KEY.CONNECTED_PERMANENCE, 0.8); | |
parameters.set(Parameters.KEY.MIN_THRESHOLD, 5); | |
parameters.set(Parameters.KEY.MAX_NEW_SYNAPSE_COUNT, 6); | |
parameters.set(Parameters.KEY.PERMANENCE_INCREMENT, 0.05); | |
parameters.set(Parameters.KEY.PERMANENCE_DECREMENT, 0.05); | |
parameters.set(Parameters.KEY.ACTIVATION_THRESHOLD, 4); | |
return parameters; | |
} | |
public Map<String, Class<? extends Classifier>> getInferredFieldsMap( | |
String field, Class<? extends Classifier> classifier) { | |
Map<String, Class<? extends Classifier>> inferredFieldsMap = new HashMap<>(); | |
inferredFieldsMap.put(field, classifier); | |
return inferredFieldsMap; | |
} | |
public static void main(String[] args) { | |
HTMLocal local = new HTMLocal(); | |
Network network = local.createBasicNetwork(); | |
int cellsPerColumn = (int)network.getParameters().get(KEY.CELLS_PER_COLUMN); | |
//Subscribe | |
network.observe().subscribe(new Observer<Inference>() { | |
@Override public void onCompleted() {} | |
@Override public void onError(Throwable e) { e.printStackTrace(); } | |
@Override | |
public void onNext(Inference inf) { | |
String layerInput = inf.getLayerInput().toString(); | |
System.out.println("--------------------------------------------------------"); | |
System.out.println("Iteration: " + inf.getRecordNum()); | |
System.out.println("===== Sequence Num: " + (inf.getRecordNum() + 1) + " ====="); | |
System.out.println("layerInput Input = " + layerInput); | |
System.out.println("SpatialPooler Output = " + Arrays.toString(inf.getFeedForwardActiveColumns())); | |
int[] predictedColumns = SDR.cellsAsColumnIndices(inf.getPredictiveCells(), cellsPerColumn); //Get the predicted column indexes | |
System.out.println("TemporalMemory Input = " + Arrays.toString(inf.getFeedForwardSparseActives())); | |
System.out.println("TemporalMemory Prediction = " + Arrays.toString(predictedColumns)); | |
Set<Cell> actives = inf.getActiveCells(); | |
int[] actCellIndices = SDR.asCellIndices(actives); | |
System.out.println("TemporalMemory Active Cells = " + Arrays.toString(actCellIndices)); | |
Set<Cell> pred = inf.getPredictiveCells(); | |
int[] predCellIndices = SDR.asCellIndices(pred); | |
System.out.println("TemporalMemory Predictive Cells = " + Arrays.toString(predCellIndices)); | |
//Anomaly | |
double score = inf.getAnomalyScore(); | |
System.out.println("Anomaly Score = " + score); | |
} | |
}); | |
// Get the Publisher | |
Publisher publisher = network.getPublisher(); | |
// Start the Network | |
network.start(); | |
// Feed in the data | |
//.addHeader("stamp, asn, probeType, responseCode, result") | |
//.addHeader("datetime, int, int, bool, int") | |
publisher.onNext("2017-05-11 06:46:21,4,5.0,1,6"); | |
publisher.onComplete(); | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment