Created February 22, 2020 18:07
Save AbdelmajidB/49ae82743d8d27046397283c366db4a1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class SemanticSegmentation { | |
private static final Logger log = LoggerFactory.getLogger(SemanticSegmentation.class); | |
private static int height = 256; | |
private static int width = 256; | |
private static int channels = 1; | |
private static int batchSize = 6; | |
private static long seed = 1234; | |
private static Random rng = new Random(seed); | |
private static int epochs = 100; | |
public static final String DATA_PATH= "C:/Users/bismi/Documents/dl4j/brain-tumor-segmentation/src/main/resources"; | |
public void run(String[] args) throws Exception { | |
//CudaEnvironment.getInstance().getConfiguration().allowMultiGPU(true); | |
//CudaEnvironment.getInstance().getConfiguration().useDevice(0); | |
log.info("Load data...."); | |
MyPathLabelGenerator labelMaker = new MyPathLabelGenerator(DATA_PATH); | |
File trainData = new File(DATA_PATH+"/trainT/images/"); | |
File testData = new File(DATA_PATH+"/testT/images/"); | |
// Define the FileSplit(PATH, ALLOWED FORMATS,random) | |
FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, rng); | |
FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, rng); | |
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker); | |
recordReader.initialize(train); | |
//recordReader.setListeners(new LogRecordListener()); | |
int labelIndex = 1; //You have 2 Writables ("columns") - index 0 is features image NDArrayWritable, index 1 is labels image NDArrayWritable | |
// DataSet Iterator | |
DataSetIterator dataTrainIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true); | |
DataSetIterator dataTestIter = new RecordReaderDataSetIterator(recordReader, 1,labelIndex, labelIndex, true); | |
//Normalize the data | |
NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler (0, 1); | |
scaler.fitLabel(true); | |
scaler.fit(dataTrainIter); | |
dataTrainIter.setPreProcessor(scaler); | |
log.info("Build model...."); | |
ComputationGraphConfiguration.GraphBuilder graph = unetBuilder(); | |
graph.addInputs("input").setInputTypes(InputType.convolutional(inputShape[0], inputShape[1], inputShape[2])); | |
ComputationGraphConfiguration conf = graph.build(); | |
ComputationGraph model = new ComputationGraph(conf); | |
model.init(); | |
UIServer uiServer=UIServer.getInstance(); | |
StatsStorage statsStorage=new InMemoryStatsStorage(); | |
uiServer.attach(statsStorage); | |
model.setListeners(new StatsListener(statsStorage)); | |
//model.setListeners(new ScoreIterationListener(1)); | |
log.info("Train model...."); | |
int i; | |
BufferedWriter bw = null; | |
FileWriter fw = null; | |
File file=new File("C:\\filename.txt"); | |
file.delete(); | |
fw = new FileWriter(file,true); | |
bw = new BufferedWriter(fw); | |
Instant start = Instant.now(); | |
for (i=0;i<epochs;i++) { | |
model.fit(dataTrainIter); | |
} | |
Instant finish = Instant.now(); | |
log.info("Training time(Minutes): s:"+start.toString()+" f:"+finish.toString()+" duration:"+Duration.between(start, finish).toMinutes()); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.