Created
October 4, 2021 16:50
-
-
Save AbdelmajidB/63cc58e46e3b6cd3ed867294ebbff158 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
private static final Logger log = LoggerFactory.getLogger(SemanticSegmentation.class); | |
private static int height = 96; | |
private static int width = 96; | |
private static int channels = 1; | |
private static int batchSize = 22; | |
private static long seed = 1234; | |
private static Random rng = new Random(seed); | |
private static int epochs = 1; | |
public static final String DATA_PATH= "C:/Users/bismi/Documents/dl4j/brain-tumor-segmentation/src/main/resources"; | |
public void run(String[] args) throws Exception { | |
log.info("Load data...."); | |
MyPathLabelGenerator labelMaker = new MyPathLabelGenerator(DATA_PATH); | |
File trainData = new File(DATA_PATH+"/trainI/images/"); | |
File testData = new File(DATA_PATH+"/testI/images/"); | |
FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, rng); | |
FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, rng); | |
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker); | |
recordReader.initialize(train); | |
int labelIndex = 1; //You have 2 Writables ("columns") - index 0 is features image NDArrayWritable, index 1 is labels image NDArrayWritable | |
// DataSet Iterator | |
DataSetIterator dataTrainIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true); | |
DataSetIterator dataTestIter = new RecordReaderDataSetIterator(recordReader, 1,labelIndex, labelIndex, true); | |
//Normalize the data | |
ImagePreProcessingScaler1 scaler = new ImagePreProcessingScaler1(0, 1); | |
dataTrainIter.setPreProcessor(scaler); | |
log.info("Build model...."); | |
ComputationGraphConfiguration.GraphBuilder graph = unetBuilder(); | |
graph.addInputs("input").setInputTypes(InputType.convolutional(inputShape[0], inputShape[1], inputShape[2])); | |
ComputationGraphConfiguration conf = graph.build(); | |
ComputationGraph model = new ComputationGraph(conf); | |
model.init(); | |
model.setListeners(new ScoreIterationListener(100)); | |
log.info("Train model...."); | |
Instant start = Instant.now(); | |
for (i=0;i<epochs;i++) { | |
model.fit(dataTrainIter); | |
} | |
Instant finish = Instant.now(); | |
log.info("Training time(Minutes): s:"+start.toString()+" f:"+finish.toString()+" duration:"+Duration.between(start, finish).toMinutes()); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment