import static org.ojalgo.ann.ArtificialNeuralNetwork.Activator.RELU; import static org.ojalgo.ann.ArtificialNeuralNetwork.Activator.SOFTMAX; import static org.ojalgo.ann.ArtificialNeuralNetwork.Error.CROSS_ENTROPY; import static org.ojalgo.ann.ArtificialNeuralNetwork.Error.HALF_SQUARED_DIFFERENCE; import java.io.File; import java.util.Arrays; import org.ojalgo.OjAlgoUtils; import org.ojalgo.ann.ArtificialNeuralNetwork; import org.ojalgo.ann.NetworkBuilder; import org.ojalgo.ann.NetworkInvoker; import org.ojalgo.ann.NetworkTrainer; import org.ojalgo.matrix.store.R032Store; import org.ojalgo.matrix.store.RawStore; import org.ojalgo.netio.BasicLogger; import org.ojalgo.structure.Access1D; /** * With v48.3 the neural network code was refactored. Here's an outline of what's new or changed. (This is NOT * a runnable program.) * * @see https://www.ojalgo.org/2020/09/neural-network-new-features-in-v48-3/ */ public class NeuralNetworkNewsAndChanges_48_3 { public static void main(final String[] args) { BasicLogger.debug(); BasicLogger.debug(NeuralNetworkNewsAndChanges_48_3.class); BasicLogger.debug(OjAlgoUtils.getTitle()); BasicLogger.debug(OjAlgoUtils.getDate()); BasicLogger.debug(); /* * Networks are now built using a per layer builder. At first you specify the network's number of * input nodes. Then you add layers, one at the time, specifying their number of output nodes and the * activator function. The output of one layer is naturally the input to the following. This * particular network has 4 input and 2 output nodes. Some would say this is a 3-layer network, but * there are only 2 calculation layers. The first calculation layer has 4 input and 6 output nodes. * The second, and final, layer has 6 input and 2 output nodes. */ ArtificialNeuralNetwork network = ArtificialNeuralNetwork.builder(4).layer(6, RELU).layer(2, SOFTMAX).get(); /* * Optionally it is possible to specify which matrix factory to use internally. This will allow * switching between double and float elements as well as different matrix representations. */ NetworkBuilder builder = ArtificialNeuralNetwork.builder(R032Store.FACTORY, 4); /* * To train a network you obtain a trainer... */ NetworkTrainer trainer = network.newTrainer(); /* * That network trainer is ready to be used, but it can be reconfigured. The trainer can be configured * with a learning rate as well as optional use of, droputs, L1 lasso regularisation and/or L2 ridge * regularisation. */ trainer.rate(0.05).dropouts().lasso(0.01).ridge(0.001); /* * The input and output can be typed as ojAlgo's most basic data type – Access1D. Just about anything * in ojAlgo "is" an Access1D. If you have arrays or lists of numbers then you can wrap them in * Access1D instances to avoid copying. Most naturally you work with ojAlgo data structures from the * beginning. */ Access1D<Double> input = Access1D.wrap(1, 2, 3, 4); Access1D<Double> output = Access1D.wrap(Arrays.asList(10.0, 20.0)); /* * Repeatedly call this, with different examples, to train the neural network. */ trainer.train(input, output); /* * To use/invoke a network you obtain an invoker... A key feature here is that you can have several * invoker instances using the same underlying network simultaneously. The invocation specific state * is in the invoker instance. */ NetworkInvoker invoker1 = network.newInvoker(); NetworkInvoker invoker2 = network.newInvoker(); output = invoker1.invoke(input); /* * Trained networks can be saved to file, and then used later */ File file = new File("somePath"); network.writeTo(file); ArtificialNeuralNetwork network2 = ArtificialNeuralNetwork.from(file); /* * It's also possible to specify a (different) matrix factory when reading a network from file. */ ArtificialNeuralNetwork network3 = ArtificialNeuralNetwork.from(RawStore.FACTORY, file); /* * What about specifying the error/loss function when training? ojAlgo supports 2 different error * functions, and which to use is dictated by the activator of the final layer. CROSS_ENTROPY has to * be used with the SOFTMAX activator, and cannot be used with any other. The correct error function * is set for you. You can manually set it, but if you set the incorrect one you'll get an exception. */ trainer.error(CROSS_ENTROPY); trainer.error(HALF_SQUARED_DIFFERENCE); } }