Last active
September 4, 2015 07:48
-
-
Save cerisara/db2193e68ee3c33b7182 to your computer and use it in GitHub Desktop.
Alternative flattening/deflattening of parameters in DL4J / Spark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Iterative reduce with
 * flat map using map partitions.
 *
 * @author Adam Gibson
 * @author Christophe Cerisara (modifications)
 */
public class IterativeReduceFlatMap implements FlatMapFunction<Iterator<DataSet>, INDArray> { | |
private String json; | |
private Broadcast<INDArray> params; | |
private static Logger log = LoggerFactory.getLogger(IterativeReduceFlatMap.class); | |
/**
 * Creates the per-partition training function.
 *
 * @param json   the network configuration serialized as JSON; it is parsed
 *               again inside {@code call} to rebuild the network
 * @param params broadcast handle to the baseline parameters that are loaded
 *               into the rebuilt network before fitting
 */
public IterativeReduceFlatMap(String json, Broadcast<INDArray> params) {
    this.params = params;
    this.json = json;
}
public static INDArray flattenParms(MultiLayerNetwork network) { | |
ArrayList<Double> ps = new ArrayList<Double>(); | |
final int nl=network.getLayers().length; | |
ps.add((double)nl); | |
for (int i=0;i<nl;i++) { | |
INDArray pp = network.getLayer(i).params(); | |
final int nx=pp.length(); | |
ps.add((double)nx); | |
for (int j=0;j<nx;j++) { | |
ps.add(pp.getDouble(j)); | |
} | |
} | |
double[] vs = new double[ps.size()]; | |
for (int i=0;i<vs.length;i++) vs[i]=ps.get(i); | |
INDArray v = Nd4j.create(vs); | |
return v; | |
} | |
public static void deflattenParms(MultiLayerNetwork network, INDArray parms) { | |
int pi=0; | |
int nLayer = (int)parms.getDouble(pi++); | |
assert nLayer == network.getLayers().length; | |
for (int i=0;i<nLayer;i++) { | |
int nx=(int)parms.getDouble(pi++); | |
double[] w = new double[nx]; | |
for (int j=0;j<nx;j++) w[j] = parms.getDouble(pi++); | |
INDArray pl = Nd4j.create(w); | |
network.getLayer(i).setParams(pl); | |
} | |
} | |
@Override | |
public Iterable<INDArray> call(Iterator<DataSet> dataSetIterator) throws Exception { | |
if (!dataSetIterator.hasNext()) { | |
return Collections.singletonList(Nd4j.zeros(params.value().shape())); | |
} | |
List<DataSet> collect = new ArrayList<DataSet>(); | |
while (dataSetIterator.hasNext()) { | |
collect.add(dataSetIterator.next()); | |
} | |
DataSet data = DataSet.merge(collect, false); | |
log.debug("Training on " + data.labelCounts()); | |
MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json)); | |
network.init(); | |
INDArray val = params.value(); | |
if (val.length() != network.numParams()) | |
throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters"); | |
network.setParameters(val); | |
network.fit(data); | |
INDArray trainedParms = flattenParms(network); | |
return Collections.singletonList(trainedParms); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* Master class for spark | |
* | |
* @author Adam Gibson | |
modified by Christophe Cerisara | |
*/ | |
public class SparkDl4jMultiLayer implements Serializable { | |
private transient SparkContext sparkContext; | |
private transient JavaSparkContext sc; | |
private MultiLayerConfiguration conf; | |
private MultiLayerNetwork network; | |
private Broadcast<INDArray> params; | |
private boolean averageEachIteration = false; | |
public final static String AVERAGE_EACH_ITERATION = "org.deeplearning4j.spark.iteration.average"; | |
private static final Logger log = LoggerFactory.getLogger(SparkDl4jMultiLayer.class); | |
/** | |
* Instantiate a multi layer spark instance | |
* with the given context and network. | |
* This is the prediction constructor | |
* | |
* @param sparkContext the spark context to use | |
* @param network the network to use | |
*/ | |
public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerNetwork network) { | |
this.sparkContext = sparkContext; | |
this.averageEachIteration = sparkContext.conf().getBoolean(AVERAGE_EACH_ITERATION, false); | |
this.network = network; | |
this.conf = this.network.getLayerWiseConfigurations().clone(); | |
sc = new JavaSparkContext(this.sparkContext); | |
this.params = sc.broadcast(network.params()); | |
} | |
/** | |
* Training constructor. Instantiate with a configuration | |
* | |
* @param sparkContext the spark context to use | |
* @param conf the configuration of the network | |
*/ | |
public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerConfiguration conf) { | |
this.sparkContext = sparkContext; | |
this.conf = conf.clone(); | |
this.averageEachIteration = sparkContext.conf().getBoolean(AVERAGE_EACH_ITERATION, false); | |
sc = new JavaSparkContext(this.sparkContext); | |
} | |
/** | |
* Training constructor. Instantiate with a configuration | |
* | |
* @param sc the spark context to use | |
* @param conf the configuration of the network | |
*/ | |
public SparkDl4jMultiLayer(JavaSparkContext sc, MultiLayerConfiguration conf) { | |
this(sc.sc(), conf); | |
} | |
/** | |
* Train a multi layer network based on the path | |
* | |
* @param path the path to the text file | |
* @param labelIndex the label index | |
* @param recordReader the record reader to parse results | |
* @return {@link MultiLayerNetwork} | |
*/ | |
public MultiLayerNetwork fit(String path, int labelIndex, RecordReader recordReader) { | |
JavaRDD<String> lines = sc.textFile(path); | |
// gotta map this to a Matrix/INDArray | |
FeedForwardLayer outputLayer = (FeedForwardLayer) conf.getConf(conf.getConfs().size() - 1).getLayer(); | |
JavaRDD<DataSet> points = lines.map(new RecordReaderFunction(recordReader | |
, labelIndex, outputLayer.getNOut())); | |
return fitDataSet(points); | |
} | |
public MultiLayerNetwork getNetwork() { | |
return network; | |
} | |
public void setNetwork(MultiLayerNetwork network) { | |
this.network = network; | |
} | |
/** | |
* Predict the given feature matrix | |
* | |
* @param features the given feature matrix | |
* @return the predictions | |
*/ | |
public Matrix predict(Matrix features) { | |
return MLLibUtil.toMatrix(network.output(MLLibUtil.toMatrix(features))); | |
} | |
/** | |
* Predict the given vector | |
* | |
* @param point the vector to predict | |
* @return the predicted vector | |
*/ | |
public Vector predict(Vector point) { | |
return MLLibUtil.toVector(network.output(MLLibUtil.toVector(point))); | |
} | |
/** | |
* Fit the given rdd given the context. | |
* This will convert the labeled points | |
* to the internal dl4j format and train the model on that | |
* | |
* @param sc the org.deeplearning4j.spark context | |
* @param rdd the rdd to fitDataSet | |
* @return the multi layer network that was fitDataSet | |
*/ | |
public MultiLayerNetwork fit(JavaSparkContext sc, JavaRDD<LabeledPoint> rdd) { | |
FeedForwardLayer outputLayer = (FeedForwardLayer) conf.getConf(conf.getConfs().size() - 1).getLayer(); | |
return fitDataSet(MLLibUtil.fromLabeledPoint(sc, rdd, outputLayer.getNOut())); | |
} | |
/** | |
* Fit the dataset rdd | |
* | |
* @param rdd the rdd to fitDataSet | |
* @return the multi layer network | |
*/ | |
public MultiLayerNetwork fitDataSet(JavaRDD<DataSet> rdd) { | |
int iterations = conf.getConf(0).getNumIterations(); | |
log.info("Running distributed training averaging each iteration " + averageEachIteration + " and " + rdd.partitions().size() + " partitions"); | |
if (!averageEachIteration) | |
runIteration(rdd); | |
else { | |
for (NeuralNetConfiguration conf : this.conf.getConfs()) | |
conf.setNumIterations(1); | |
MultiLayerNetwork network = new MultiLayerNetwork(conf); | |
network.init(); | |
final INDArray params = network.params(); | |
this.params = sc.broadcast(params); | |
for (int i = 0; i < iterations; i++) | |
runIteration(rdd); | |
} | |
return network; | |
} | |
private void runIteration(JavaRDD<DataSet> rdd) { | |
MultiLayerNetwork network = new MultiLayerNetwork(conf); | |
network.init(); | |
final INDArray params = network.params(); | |
this.params = sc.broadcast(params); | |
log.info("Broadcasting initial parameters of length " + params.length()); | |
int paramsLength = network.numParams(); | |
if (params.length() != paramsLength) | |
throw new IllegalStateException("Number of params " + paramsLength + " was not equal to " + params.length()); | |
// pourquoi ne sampler que env. 40% du corpus pour le train ? | |
// JavaRDD<DataSet> c = rdd.sample(true, 0.4); | |
JavaRDD<DataSet> c = rdd; | |
JavaRDD<INDArray> results = c.mapPartitions(new IterativeReduceFlatMap(conf.toJson(), this.params)).cache(); | |
log.info("Ran iterative reduce...averaging results now."); | |
// some indices in the flattened parameters are not real parameters, but are size indicators. | |
// but there shouldn't be any problem with averaging these indicators as well, because they should be constant and the same | |
Adder a = new Adder(results.collect().get(0).length()); | |
results.foreach(a); | |
INDArray tmpParams = a.getAccumulator().value(); | |
log.info("Accumulated parameters"); | |
tmpParams.divi(rdd.partitions().size()); | |
log.info("Divided by partitions"); | |
IterativeReduceFlatMap.deflattenParms(network, tmpParams); | |
log.info("Set parameters"); | |
this.network = network; | |
} | |
/** | |
* Train a multi layer network | |
* | |
* @param data the data to train on | |
* @param conf the configuration of the network | |
* @return the fit multi layer network | |
*/ | |
public static MultiLayerNetwork train(JavaRDD<LabeledPoint> data, MultiLayerConfiguration conf) { | |
SparkDl4jMultiLayer multiLayer = new SparkDl4jMultiLayer(data.context(), conf); | |
return multiLayer.fit(new JavaSparkContext(data.context()), data); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment