import static org.ojalgo.ann.ArtificialNeuralNetwork.Activator.RELU;
import static org.ojalgo.ann.ArtificialNeuralNetwork.Activator.SOFTMAX;
import static org.ojalgo.ann.ArtificialNeuralNetwork.Error.CROSS_ENTROPY;
import static org.ojalgo.ann.ArtificialNeuralNetwork.Error.HALF_SQUARED_DIFFERENCE;

import java.io.File;
import java.util.Arrays;

import org.ojalgo.OjAlgoUtils;
import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.ann.NetworkBuilder;
import org.ojalgo.ann.NetworkInvoker;
import org.ojalgo.ann.NetworkTrainer;
import org.ojalgo.matrix.store.R032Store;
import org.ojalgo.matrix.store.RawStore;
import org.ojalgo.netio.BasicLogger;
import org.ojalgo.structure.Access1D;

/**
 * With v48.3 the neural network code was refactored. Here's an outline of what's new or changed. (This is
 * NOT a runnable program — it only illustrates the API.)
 *
 * @see <a href="https://www.ojalgo.org/2020/09/neural-network-new-features-in-v48-3/">Neural network — new
 *      features in v48.3</a>
 */
public class NeuralNetworkNewsAndChanges_48_3 {

    public static void main(final String[] args) {

        BasicLogger.debug();
        BasicLogger.debug(NeuralNetworkNewsAndChanges_48_3.class);
        BasicLogger.debug(OjAlgoUtils.getTitle());
        BasicLogger.debug(OjAlgoUtils.getDate());
        BasicLogger.debug();

        /*
         * Networks are now built using a per layer builder. At first you specify the network's number of
         * input nodes. Then you add layers, one at the time, specifying their number of output nodes and the
         * activator function. The output of one layer is naturally the input to the following. This
         * particular network has 4 input and 2 output nodes. Some would say this is a 3-layer network, but
         * there are only 2 calculation layers. The first calculation layer has 4 input and 6 output nodes.
         * The second, and final, layer has 6 input and 2 output nodes.
         */
        ArtificialNeuralNetwork network = ArtificialNeuralNetwork.builder(4).layer(6, RELU).layer(2, SOFTMAX).get();

        /*
         * Optionally it is possible to specify which matrix factory to use internally. This will allow
         * switching between double and float elements as well as different matrix representations.
         */
        NetworkBuilder builder = ArtificialNeuralNetwork.builder(R032Store.FACTORY, 4);

        /*
         * To train a network you obtain a trainer...
         */
        NetworkTrainer trainer = network.newTrainer();

        /*
         * That network trainer is ready to be used, but it can be reconfigured. The trainer can be configured
         * with a learning rate as well as optional use of, droputs, L1 lasso regularisation and/or L2 ridge
         * regularisation.
         */
        trainer.rate(0.05).dropouts().lasso(0.01).ridge(0.001);

        /*
         * The input and output can be typed as ojAlgo's most basic data type – Access1D. Just about anything
         * in ojAlgo "is" an Access1D. If you have arrays or lists of numbers then you can wrap them in
         * Access1D instances to avoid copying. Most naturally you work with ojAlgo data structures from the
         * beginning.
         */
        Access1D<Double> input = Access1D.wrap(1, 2, 3, 4);
        Access1D<Double> output = Access1D.wrap(Arrays.asList(10.0, 20.0));

        /*
         * Repeatedly call this, with different examples, to train the neural network.
         */
        trainer.train(input, output);

        /*
         * To use/invoke a network you obtain an invoker... A key feature here is that you can have several
         * invoker instances using the same underlying network simultaneously. The invocation specific state
         * is in the invoker instance.
         */
        NetworkInvoker invoker1 = network.newInvoker();
        NetworkInvoker invoker2 = network.newInvoker();

        output = invoker1.invoke(input);

        /*
         * Trained networks can be saved to file, and then used later
         */
        File file = new File("somePath");
        network.writeTo(file);
        ArtificialNeuralNetwork network2 = ArtificialNeuralNetwork.from(file);
        /*
         * It's also possible to specify a (different) matrix factory when reading a network from file.
         */
        ArtificialNeuralNetwork network3 = ArtificialNeuralNetwork.from(RawStore.FACTORY, file);

        /*
         * What about specifying the error/loss function when training? ojAlgo supports 2 different error
         * functions, and which to use is dictated by the activator of the final layer. CROSS_ENTROPY has to
         * be used with the SOFTMAX activator, and cannot be used with any other. The correct error function
         * is set for you. You can manually set it, but if you set the incorrect one you'll get an exception.
         */
        trainer.error(CROSS_ENTROPY);
        trainer.error(HALF_SQUARED_DIFFERENCE);
    }

}