Skip to content

Instantly share code, notes, and snippets.

@sjaiswal25
Created November 17, 2018 07:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sjaiswal25/fb3c0eb1fcb425856dfa4c5b479b4c28 to your computer and use it in GitHub Desktop.
Save sjaiswal25/fb3c0eb1fcb425856dfa4c5b479b4c28 to your computer and use it in GitHub Desktop.
/**
 * Imports a pretrained Keras Tiny YOLOv2 model (HDF5) as a DL4J {@link ComputationGraph},
 * attaches a {@link Yolo2OutputLayer} on top of layer {@code "conv2d_9"}, prints a summary and
 * writes the resulting model to a DL4J zip file.
 *
 * <p>Usage: {@code YoloKerasModelImport [kerasH5Path [outputZipPath]]}. Any omitted argument
 * falls back to the original hard-coded path, so running without arguments behaves as before.
 */
public class YoloKerasModelImport {
    /** Number of anchor (prior) boxes; must equal the number of rows in {@link #priorBoxes}. */
    public static int nBoxes = 5;

    /**
     * Anchor box priors passed to {@link Yolo2OutputLayer.Builder#boundingBoxPriors}, one row per
     * box, each row a {width, height} pair. NOTE(review): DL4J documents these as grid-cell units;
     * confirm they match the grid of the imported model.
     */
    public static double [][] priorBoxes = {{0.57273, 0.677385}, {1.87446, 2.06253},
            {3.33843, 5.47434}, {7.88282, 3.52778}, {9.77052, 9.16828}};

    /** Seed for the fine-tune configuration. It is never assigned, so the effective seed is 0. */
    private static long seed;

    /**
     * Workspace mode for the fine-tune configuration. It is never assigned, so null is passed.
     * Presumably DL4J then keeps the imported graph's own workspace settings; TODO confirm.
     */
    private static WorkspaceMode workspaceMode;

    /** Keras HDF5 model read when no first command-line argument is given. */
    private static final String DEFAULT_PRETRAINED_MODEL_PATH = "/home/project/data/tiny_yolov2.h5";

    /** DL4J zip written when no second command-line argument is given. */
    private static final String DEFAULT_OUTPUT_MODEL_PATH = "/home/project/data/tiny_yolov2.zip";

    /**
     * Runs the import, adds the YOLOv2 output layer and saves the model.
     *
     * @param args optional {@code [kerasH5Path [outputZipPath]]}
     * @throws IOException if the Keras file cannot be read or the zip cannot be written
     * @throws UnsupportedKerasConfigurationException if the Keras model uses unsupported features
     * @throws InvalidKerasConfigurationException if the Keras model configuration is invalid
     * @throws IllegalStateException if {@link #priorBoxes} does not hold {@link #nBoxes} pairs
     */
    public static void main(String[] args)
            throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        String pretrainedModelPath = argOrDefault(args, 0, DEFAULT_PRETRAINED_MODEL_PATH);
        String outputModelPath = argOrDefault(args, 1, DEFAULT_OUTPUT_MODEL_PATH);

        // Fail fast on inconsistent priors before the (slow) model import.
        validatePriors();

        // Map Keras "Lambda" layers to KerasSpaceToDepth so the import can handle them.
        KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);
        // false = do not enforce the Keras training configuration (enforceTrainingConfig).
        ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(pretrainedModelPath, false);

        INDArray priors = Nd4j.create(priorBoxes);
        ComputationGraph model = new TransferLearning.GraphBuilder(graph)
                .fineTuneConfiguration(buildFineTuneConfiguration())
                .addLayer("outputs", new Yolo2OutputLayer.Builder().boundingBoxPriors(priors).build(), "conv2d_9")
                .setOutputs("outputs")
                .build();

        // NOTE(review): tiny YOLOv2 is commonly trained at 416x416; confirm 224x224x3 matches the model.
        System.out.println(model.summary(InputType.convolutional(224, 224, 3)));
        // false = do not save the updater state.
        ModelSerializer.writeModel(model, outputModelPath, false);
    }

    /** Builds the fine-tune settings applied to every layer of the imported graph. */
    private static FineTuneConfiguration buildFineTuneConfiguration() {
        return new FineTuneConfiguration.Builder()
                .seed(seed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(1.0)
                .updater(new Adam.Builder().learningRate(1e-3).build())
                .l2(0.00001)
                .activation(Activation.IDENTITY)
                .trainingWorkspaceMode(workspaceMode)
                .inferenceWorkspaceMode(workspaceMode)
                .build();
    }

    /** Returns {@code args[index]} when present, otherwise {@code fallback}. */
    private static String argOrDefault(String[] args, int index, String fallback) {
        return (args != null && args.length > index && args[index] != null) ? args[index] : fallback;
    }

    /**
     * Checks that {@link #priorBoxes} has exactly {@link #nBoxes} rows of two values each.
     *
     * @throws IllegalStateException if the count or any row shape is wrong
     */
    private static void validatePriors() {
        if (priorBoxes == null || priorBoxes.length != nBoxes) {
            throw new IllegalStateException("Expected " + nBoxes + " prior boxes but found "
                    + (priorBoxes == null ? "null" : String.valueOf(priorBoxes.length)));
        }
        for (int i = 0; i < priorBoxes.length; i++) {
            if (priorBoxes[i] == null || priorBoxes[i].length != 2) {
                throw new IllegalStateException("Prior box " + i + " must be a {width, height} pair");
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment