-
-
Save ethrbunny/4d5a6c034febbc9b3f1a27d74c954690 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.numenta.nupic.Parameters; | |
import org.numenta.nupic.algorithms.*; | |
import org.numenta.nupic.network.Network; | |
import org.numenta.nupic.network.PublisherSupplier; | |
import org.numenta.nupic.network.sensor.ObservableSensor; | |
import org.numenta.nupic.network.sensor.Sensor; | |
import org.numenta.nupic.network.sensor.SensorParams; | |
import org.numenta.nupic.util.FastRandom; | |
import org.numenta.nupic.util.Tuple; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.util.HashMap; | |
import java.util.Map; | |
/** | |
* Created on 5/11/17. | |
*/ | |
public class HTMLocal { | |
private static final Logger LOGGER = LoggerFactory.getLogger(HTMLocal.class); | |
// https://github.com/numenta/htm.java-examples/blob/master/src/main/java/org/numenta/nupic/examples/napi/hotgym/NetworkDemoHarness.java | |
public Network createBasicNetwork() { | |
Parameters p = getParameters().copy(); | |
p = p.union(getDayDemoTestEncoderParams()); | |
p.set(Parameters.KEY.RANDOM, new FastRandom(42)); | |
// p.set(Parameters.KEY.INFERRED_FIELDS, getInferredFieldsMap("stamp", SDRClassifier.class)); | |
// The "headers" are the titles of your comma separated fields; (could be "timestamp,consumption,location" for 3 fields) | |
// The "Data Type" of the field (see FieldMetaTypes) (could be "datetime,float,geo" for 3 field types corresponding to above) | |
Sensor<ObservableSensor<String[]>> sensor = Sensor.create( | |
ObservableSensor::create, SensorParams.create(SensorParams.Keys::obs, new Object[] {"name", | |
PublisherSupplier.builder() | |
.addHeader("stamp, asn, probeType, responseCode, result") | |
.addHeader("datetime, int, int, bool, int") | |
.addHeader("T, B") // Special flag. "B" means Blank (see Tests for other examples) | |
.build() })); | |
Network network = Network.create("test network", p).add(Network.createRegion("r1") | |
.add(Network.createLayer("1", p) | |
// .alterParameter(KEY.AUTO_CLASSIFY, true) // <--- Remove this line if doing anomalies and not predictions | |
.add(Anomaly.create()) // <--- Remove this line if doing predictions and not anomalies | |
.add(new TemporalMemory()) | |
.add(new SpatialPooler()) | |
.add(sensor))); | |
return network; | |
} | |
public Parameters getDayDemoTestEncoderParams() { | |
Map<String, Map<String, Object>> fieldEncodings = getDayDemoFieldEncodingMap(); | |
Parameters p = Parameters.getEncoderDefaultParameters(); | |
p.set(Parameters.KEY.FIELD_ENCODING_MAP, fieldEncodings); | |
return p; | |
} | |
public Map<String, Map<String, Object>> getDayDemoFieldEncodingMap() { | |
Map<String, Map<String, Object>> fieldEncodings = setupMap( | |
null, | |
0, // n | |
0, // w | |
0, 0, 0, 0, null, null, null, | |
"stamp", "datetime", "DateEncoder"); | |
fieldEncodings.get("stamp").put(Parameters.KEY.DATEFIELD_PATTERN.getFieldName(), "YYYY-MM-dd HH:mm:ss"); | |
fieldEncodings.get("stamp").put(Parameters.KEY.DATEFIELD_DOFW.getFieldName(), new Tuple(1, 1.0)); // Day of week | |
fieldEncodings.get("stamp").put(Parameters.KEY.DATEFIELD_TOFD.getFieldName(), new Tuple(21, 9.5)); // Time of day | |
fieldEncodings = setupMap( | |
fieldEncodings, | |
500, // n | |
21, // w | |
0.0, 99999.0, -1, 1, null, null, null, | |
"asn", "number", "RandomDistributedScalarEncoder"); | |
fieldEncodings = setupMap( | |
fieldEncodings, | |
500, // n | |
21, // w | |
0.0, 20, -1, 1, null, null, null, | |
"probeType", "number", "ScalarEncoder"); | |
fieldEncodings = setupMap( | |
fieldEncodings, | |
500, // n | |
21, // w | |
0.0, 5.0, -1, 1, null, null, null, | |
"responseCode", "number", "ScalarEncoder"); | |
fieldEncodings = setupMap( | |
fieldEncodings, | |
500, // n | |
21, // w | |
0.0, 99999.0, -1, 1, null, null, null, | |
"results", "number", "RandomDistributedScalarEncoder"); | |
return fieldEncodings; | |
} | |
public Map<String, Map<String, Object>> setupMap( | |
Map<String, Map<String, Object>> map, | |
int n, int w, double min, double max, double radius, double resolution, Boolean periodic, | |
Boolean clip, Boolean forced, String fieldName, String fieldType, String encoderType) { | |
if(map == null) { | |
map = new HashMap<>(); | |
} | |
Map<String, Object> inner; | |
if((inner = map.get(fieldName)) == null) { | |
map.put(fieldName, inner = new HashMap<>()); | |
} | |
inner.put("n", n); | |
inner.put("w", w); | |
inner.put("minVal", min); | |
inner.put("maxVal", max); | |
if(radius >= 0 ) { | |
inner.put("radius", radius); | |
} | |
if(resolution >= 0 ) { | |
inner.put("resolution", resolution); | |
} | |
if(periodic != null) inner.put("periodic", periodic); | |
if(clip != null) inner.put("clipInput", clip); | |
if(forced != null) inner.put("forced", forced); | |
if(fieldName != null) inner.put("fieldName", fieldName); | |
if(fieldType != null) inner.put("fieldType", fieldType); | |
if(encoderType != null) inner.put("encoderType", encoderType); | |
return map; | |
} | |
public Parameters getParameters() { | |
Parameters parameters = Parameters.getAllDefaultParameters(); | |
parameters.set(Parameters.KEY.INPUT_DIMENSIONS, new int[] { 8 }); | |
parameters.set(Parameters.KEY.COLUMN_DIMENSIONS, new int[] { 20 }); | |
parameters.set(Parameters.KEY.CELLS_PER_COLUMN, 6); | |
//SpatialPooler specific | |
parameters.set(Parameters.KEY.POTENTIAL_RADIUS, 12);//3 | |
parameters.set(Parameters.KEY.POTENTIAL_PCT, 0.5);//0.5 | |
parameters.set(Parameters.KEY.GLOBAL_INHIBITION, false); | |
parameters.set(Parameters.KEY.LOCAL_AREA_DENSITY, -1.0); | |
parameters.set(Parameters.KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 5.0); | |
parameters.set(Parameters.KEY.STIMULUS_THRESHOLD, 1.0); | |
parameters.set(Parameters.KEY.SYN_PERM_INACTIVE_DEC, 0.01); | |
parameters.set(Parameters.KEY.SYN_PERM_ACTIVE_INC, 0.1); | |
parameters.set(Parameters.KEY.SYN_PERM_TRIM_THRESHOLD, 0.05); | |
parameters.set(Parameters.KEY.SYN_PERM_CONNECTED, 0.1); | |
parameters.set(Parameters.KEY.MIN_PCT_OVERLAP_DUTY_CYCLES, 0.1); | |
parameters.set(Parameters.KEY.MIN_PCT_ACTIVE_DUTY_CYCLES, 0.1); | |
parameters.set(Parameters.KEY.DUTY_CYCLE_PERIOD, 10); | |
parameters.set(Parameters.KEY.MAX_BOOST, 10.0); | |
parameters.set(Parameters.KEY.SEED, 42); | |
//Temporal Memory specific | |
parameters.set(Parameters.KEY.INITIAL_PERMANENCE, 0.2); | |
parameters.set(Parameters.KEY.CONNECTED_PERMANENCE, 0.8); | |
parameters.set(Parameters.KEY.MIN_THRESHOLD, 5); | |
parameters.set(Parameters.KEY.MAX_NEW_SYNAPSE_COUNT, 6); | |
parameters.set(Parameters.KEY.PERMANENCE_INCREMENT, 0.05); | |
parameters.set(Parameters.KEY.PERMANENCE_DECREMENT, 0.05); | |
parameters.set(Parameters.KEY.ACTIVATION_THRESHOLD, 4); | |
return parameters; | |
} | |
public Map<String, Class<? extends Classifier>> getInferredFieldsMap( | |
String field, Class<? extends Classifier> classifier) { | |
Map<String, Class<? extends Classifier>> inferredFieldsMap = new HashMap<>(); | |
inferredFieldsMap.put(field, classifier); | |
return inferredFieldsMap; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment