Skip to content

Instantly share code, notes, and snippets.

@AbdelmajidB
Created October 4, 2021 16:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AbdelmajidB/63cc58e46e3b6cd3ed867294ebbff158 to your computer and use it in GitHub Desktop.
Save AbdelmajidB/63cc58e46e3b6cd3ed867294ebbff158 to your computer and use it in GitHub Desktop.
private static final Logger log = LoggerFactory.getLogger(SemanticSegmentation.class);

/** Local dataset root; run() reads the training and test images from sub-folders below it. */
public static final String DATA_PATH= "C:/Users/bismi/Documents/dl4j/brain-tumor-segmentation/src/main/resources";

// Input image geometry handed to the image record reader: 96 x 96 pixels, one channel.
private static int height = 96;
private static int width = 96;
private static int channels = 1;

// Training hyper-parameters: images per minibatch and number of passes over the training set.
private static int batchSize = 22;
private static int epochs = 1;

// Fixed seed; rng is passed to the FileSplit constructors in run(), so any randomized
// file ordering is repeatable between runs. seed must stay declared before rng.
private static long seed = 1234L;
private static Random rng = new Random(seed);
/**
 * Loads the training and test image splits, builds the computation graph returned by
 * {@code unetBuilder()}, and trains it for {@code epochs} epochs, logging the wall-clock
 * training time.
 *
 * @param args command-line arguments (not read in this part of the method)
 * @throws Exception propagated from record-reader initialization or model training
 */
public void run(String[] args) throws Exception {
    log.info("Load data....");
    MyPathLabelGenerator labelMaker = new MyPathLabelGenerator(DATA_PATH);
    File trainData = new File(DATA_PATH + "/trainI/images/");
    File testData = new File(DATA_PATH + "/testI/images/");
    FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, rng);
    FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, rng);

    // You have 2 Writables ("columns"): index 0 is the features image NDArrayWritable,
    // index 1 is the labels image NDArrayWritable.
    int labelIndex = 1;

    // Training reader/iterator, bound to the training split.
    ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
    recordReader.initialize(train);
    DataSetIterator dataTrainIter =
            new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);

    // The test iterator needs its own reader bound to the test split. Sharing the training
    // reader would make the "test" iterator walk the training images and leave `test` unused.
    ImageRecordReader testRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
    testRecordReader.initialize(test);
    DataSetIterator dataTestIter =
            new RecordReaderDataSetIterator(testRecordReader, 1, labelIndex, labelIndex, true);

    // Normalize the data; apply the same scaler to both iterators so evaluation sees inputs
    // prepared exactly like the training inputs.
    ImagePreProcessingScaler1 scaler = new ImagePreProcessingScaler1(0, 1);
    dataTrainIter.setPreProcessor(scaler);
    dataTestIter.setPreProcessor(scaler);

    log.info("Build model....");
    ComputationGraphConfiguration.GraphBuilder graph = unetBuilder();
    // NOTE(review): inputShape is not declared in this excerpt; presumably a field holding
    // the (height, width, channels) of the input - confirm it matches the reader settings.
    graph.addInputs("input").setInputTypes(InputType.convolutional(inputShape[0], inputShape[1], inputShape[2]));
    ComputationGraphConfiguration conf = graph.build();
    ComputationGraph model = new ComputationGraph(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(100));

    log.info("Train model....");
    Instant start = Instant.now();
    for (int epoch = 0; epoch < epochs; epoch++) {
        model.fit(dataTrainIter);
    }
    Instant finish = Instant.now();
    // Parameterized logging; the formatted message is identical to the former concatenation.
    // Duration.toMinutes() truncates, so runs shorter than a minute report 0.
    log.info("Training time(Minutes): s:{} f:{} duration:{}",
            start, finish, Duration.between(start, finish).toMinutes());
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment