SameDiffLayer idea
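The idea sketched here: a subclassing API for custom DL4J layers backed by SameDiff. A concrete layer declares its parameters once in buildParameters() (via addWeight()/addBias()) and its forward pass in defineFeedForward(); the abstract base class takes care of registering those parameters with DL4J. Three files follow: an example layer (MyCustomLayer), the abstract base class (SameDiffLayer), and a smoke test (TestSameDiffCustom).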
MyCustomLayer.java
package org.deeplearning4j.samediff.testlayers;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.UniformInitScheme;

import java.util.Collections;
import java.util.List;

/**
 * Example custom layer: two stacked matrix multiplies plus a bias,
 * followed by the configured activation function (if any).
 */
public class MyCustomLayer extends SameDiffLayer {

    private final int nOut;

    private SDVariable W1;
    private SDVariable W2;
    private SDVariable b;

    public MyCustomLayer(int nOut) {
        this.nOut = nOut;
    }

    @Override
    public void buildParameters(SameDiff sd, int[] inputShape) {
        WeightInitScheme init = UniformInitScheme.builder().build();
        // inputShape is [minibatch, nIn]; weights are sized from the actual input
        W1 = addWeight(sd, "W1", new int[]{inputShape[1], nOut}, init, true);
        W2 = addWeight(sd, "W2", new int[]{nOut, nOut}, init, true);
        b = addBias(sd, "b", new int[]{1, nOut}, 1.0, true);
    }

    @Override
    public List<SDVariable> defineFeedForward(SameDiff sd, SDVariable layerInput) {
        SDVariable mmul1 = sd.mmul(layerInput, W1);
        SDVariable mmul2 = sd.mmul(mmul1, W2);
        SDVariable z = mmul2.add(b);
        if (getActivation() != null) {
            return Collections.singletonList(getActivation().asSameDiff(sd, z));
        } else {
            return Collections.singletonList(z);
        }
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        return InputType.feedForward(nOut);
    }
}
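For a sense of how compact a subclass can be under this contract, here is a minimal sketch of a plain dense layer (out = activation(in * W + b)). SketchDenseLayer is a hypothetical name, not part of the gist; it assumes only the addWeight()/addBias()/defineFeedForward() API of the SameDiffLayer base class below.

package org.deeplearning4j.samediff.testlayers;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.weightinit.impl.UniformInitScheme;

import java.util.Collections;
import java.util.List;

// Hypothetical sketch, not part of the gist: a single-matmul dense layer
// built on the same SameDiffLayer contract as MyCustomLayer above.
public class SketchDenseLayer extends SameDiffLayer {

    private final int nOut;
    private SDVariable W;
    private SDVariable b;

    public SketchDenseLayer(int nOut) {
        this.nOut = nOut;
    }

    @Override
    public void buildParameters(SameDiff sd, int[] inputShape) {
        // One weight matrix sized from the actual input, one bias row
        W = addWeight(sd, "W", new int[]{inputShape[1], nOut}, UniformInitScheme.builder().build(), true);
        b = addBias(sd, "b", new int[]{1, nOut}, 0.0, true);
    }

    @Override
    public List<SDVariable> defineFeedForward(SameDiff sd, SDVariable layerInput) {
        SDVariable z = sd.mmul(layerInput, W).add(b);
        return Collections.singletonList(getActivation() != null
                ? getActivation().asSameDiff(sd, z)
                : z);
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        return InputType.feedForward(nOut);
    }
}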
SameDiffLayer.java
package org.deeplearning4j.samediff.testlayers;

import lombok.Data;
import lombok.EqualsAndHashCode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.samediff.BaseSameDiffLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayerUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.weightinit.WeightInitScheme;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Abstract base class: subclasses declare their parameters in buildParameters()
 * (via addWeight()/addBias()) and their forward pass in defineFeedForward().
 * This class wires those parameters into DL4J's parameter table.
 */
@Data
@EqualsAndHashCode(callSuper = true)
public abstract class SameDiffLayer extends BaseSameDiffLayer {

    private int nIn = 0;
    private int nOut = 0;

    private Map<String, SDVariable> weights = new HashMap<>();
    private Map<String, SDVariable> biases = new HashMap<>();
    private Map<String, Double> biasInits = new HashMap<>();

    private Activation activation;
    private boolean built = false;

    public SameDiffLayer() {}

    public SameDiffLayer(int nOut) {
        this.nOut = nOut;
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
        // No-op: the input size is taken from the actual input shape in defineLayer()
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return null;
    }

    public SDVariable addWeight(SameDiff sd, String name, int[] shape, WeightInitScheme init, boolean trainable) {
        SDVariable var = sd.var(name, shape, init);
        if (trainable) {
            weights.put(name, var);
        }
        return var;
    }

    public SDVariable addBias(SameDiff sd, String name, int[] shape, double biasInit, boolean trainable) {
        SDVariable var = sd.var(name, shape);
        if (trainable) {
            biases.put(name, var);
            // Record the initial value so initializeParams() can apply it later
            biasInits.put(name, biasInit);
        }
        return var;
    }

    public abstract void buildParameters(SameDiff sd, int[] inputShape);

    public abstract List<SDVariable> defineFeedForward(SameDiff sd, SDVariable layerInput);

    @Override
    public List<String> defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable) {
        if (!built) {
            buildParameters(sameDiff, layerInput.getShape());
            built = true;   // build parameters only once
        }
        paramTable.putAll(weights);
        paramTable.putAll(biases);
        List<SDVariable> sdVariables = defineFeedForward(sameDiff, layerInput);
        return sdVariables.stream().map(SDVariable::getVarName).collect(Collectors.toList());
    }

    @Override
    public List<String> weightKeys() {
        return new ArrayList<>(weights.keySet());
    }

    @Override
    public List<String> biasKeys() {
        return new ArrayList<>(biases.keySet());
    }

    @Override
    public Map<String, int[]> paramShapes() {
        Map<String, int[]> ret = new HashMap<>();
        for (Map.Entry<String, SDVariable> entry : weights.entrySet()) {
            ret.put(entry.getKey(), entry.getValue().getShape());
        }
        for (Map.Entry<String, SDVariable> entry : biases.entrySet()) {
            ret.put(entry.getKey(), entry.getValue().getShape());
        }
        return ret;
    }

    @Override
    public void initializeParams(Map<String, INDArray> params) {
        for (Map.Entry<String, INDArray> e : params.entrySet()) {
            if (weights.containsKey(e.getKey())) {
                SDVariable weightVar = weights.get(e.getKey());
                // Initialize from the variable's own shape rather than assuming [nIn, nOut],
                // so weights such as W2 ([nOut, nOut]) are sized correctly
                int[] shape = weightVar.getShape();
                WeightInitScheme weightInitScheme = weightVar.getWeightInitScheme();
                WeightInit init = WeightInit.valueOf(weightInitScheme.type().name());
                WeightInitUtil.initWeights(shape[0], shape[1], shape, init, null, 'f', e.getValue());
            } else if (biases.containsKey(e.getKey())) {
                e.getValue().assign(biasInits.get(e.getKey()));
            }
        }
    }

    @Override
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
        if (activation == null) {
            activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn());
        }
    }
}
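A note on the design: the weights and biases maps do double duty. defineLayer() publishes the SDVariables into DL4J's parameter table, while weightKeys(), biasKeys(), and paramShapes() expose the same information so DL4J can allocate the backing INDArrays, which initializeParams() then fills using each variable's own weight-init scheme (or the recorded bias init value).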
TestSameDiffCustom.java
package org.deeplearning4j.samediff;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.samediff.testlayers.MyCustomLayer;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.Assert.assertArrayEquals;

@Slf4j
public class TestSameDiffCustom {

    @Test
    public void testMyCustomLayer() {
        int nIn = 3;
        int nOut = 4;
        int nMinibatch = 5;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new MyCustomLayer(nOut))
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Forward pass on random input; output shape should be [minibatch, nOut]
        INDArray output = net.output(Nd4j.rand(nMinibatch, nIn));
        assertArrayEquals(new int[]{nMinibatch, nOut}, output.shape());
    }
}