Skip to content

Instantly share code, notes, and snippets.

@AbdelmajidB
Created February 22, 2020 18:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AbdelmajidB/49ae82743d8d27046397283c366db4a1 to your computer and use it in GitHub Desktop.
Save AbdelmajidB/49ae82743d8d27046397283c366db4a1 to your computer and use it in GitHub Desktop.
public class SemanticSegmentation {
private static final Logger log = LoggerFactory.getLogger(SemanticSegmentation.class);
/** Target image height, in pixels, passed to the {@code ImageRecordReader}. */
private static int height = 256;
/** Target image width, in pixels, passed to the {@code ImageRecordReader}. */
private static int width = 256;
/** Number of image channels passed to the {@code ImageRecordReader} (1 = single channel). */
private static int channels = 1;
/** Mini-batch size of the training iterator. */
private static int batchSize = 6;
/** Seed for {@link #rng}, so the file-split shuffling is reproducible between runs. */
private static long seed = 1234;
/** Shared random source used to shuffle the train/test {@code FileSplit}s; must follow {@code seed}. */
private static Random rng = new Random(seed);
/** Number of passes over the training iterator. */
private static int epochs = 100;
/**
 * Root directory containing the {@code trainT/images/} and {@code testT/images/} folders.
 *
 * <p>Defaults to the original author's local checkout. Override it on other machines with
 * {@code -Dsegmentation.dataPath=/path/to/resources} instead of editing the source.
 */
public static final String DATA_PATH = System.getProperty(
        "segmentation.dataPath",
        "C:/Users/bismi/Documents/dl4j/brain-tumor-segmentation/src/main/resources");
/**
 * Loads the segmentation image data, builds the U-Net computation graph, and trains it for
 * {@code epochs} epochs while publishing statistics to the DL4J training UI.
 *
 * <p>Training images are read from {@code DATA_PATH/trainT/images/} and test images from
 * {@code DATA_PATH/testT/images/}. Labels come from {@code MyPathLabelGenerator}. Features and
 * labels are min-max scaled to [0, 1]. The scaling statistics are fitted on the training set only
 * and then applied to both the training and test iterators.
 *
 * @param args command-line arguments (not used by this method)
 * @throws Exception if loading the data, building the model, or training fails
 */
public void run(String[] args) throws Exception {
    log.info("Load data....");
    MyPathLabelGenerator labelMaker = new MyPathLabelGenerator(DATA_PATH);
    File trainData = new File(DATA_PATH + "/trainT/images/");
    File testData = new File(DATA_PATH + "/testT/images/");
    // FileSplit(path, allowed formats, rng): the shared seeded rng makes file order reproducible.
    FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, rng);
    FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, rng);

    // One reader per split. A record reader carries iteration state, so the test iterator must not
    // share the training reader; sharing it would also mean the test split is never read.
    ImageRecordReader trainRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
    trainRecordReader.initialize(train);
    ImageRecordReader testRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
    testRecordReader.initialize(test);

    // Each record has 2 writables ("columns"): index 0 is the features image NDArrayWritable,
    // index 1 is the labels image NDArrayWritable.
    int labelIndex = 1;
    DataSetIterator dataTrainIter =
            new RecordReaderDataSetIterator(trainRecordReader, batchSize, labelIndex, labelIndex, true);
    // Test iterator (batch size 1). It is not consumed in this method yet, but it is now wired to
    // the real test split and uses the same preprocessing as training.
    DataSetIterator dataTestIter =
            new RecordReaderDataSetIterator(testRecordReader, 1, labelIndex, labelIndex, true);

    // Fit min-max statistics (features and labels) on the training data only, then apply the same
    // scaler to both iterators.
    NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler(0, 1);
    scaler.fitLabel(true);
    scaler.fit(dataTrainIter);
    dataTrainIter.setPreProcessor(scaler);
    dataTestIter.setPreProcessor(scaler);

    log.info("Build model....");
    ComputationGraphConfiguration.GraphBuilder graph = unetBuilder();
    graph.addInputs("input").setInputTypes(InputType.convolutional(inputShape[0], inputShape[1], inputShape[2]));
    ComputationGraphConfiguration conf = graph.build();
    ComputationGraph model = new ComputationGraph(conf);
    model.init();

    // Publish training statistics to the DL4J UI server.
    UIServer uiServer = UIServer.getInstance();
    StatsStorage statsStorage = new InMemoryStatsStorage();
    uiServer.attach(statsStorage);
    model.setListeners(new StatsListener(statsStorage));

    log.info("Train model....");
    Instant start = Instant.now();
    for (int epoch = 0; epoch < epochs; epoch++) {
        model.fit(dataTrainIter);
    }
    Instant finish = Instant.now();
    log.info("Training time(Minutes): s:{} f:{} duration:{}",
            start, finish, Duration.between(start, finish).toMinutes());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment