@KexinFeng
Last active December 16, 2022 17:24
Transfer learning experiment with the FreshFruit dataset: reducing the training data size.

This gist demonstrates the experiment code that measures validation accuracy versus training dataset size. The goal is to test how much the training data size can be reduced when transfer learning is used; the smaller the required dataset, the lower the data annotation cost.

This demo is built with Gradle.

Project structure

TransferFreshFruitTrainExperiment.java is stored in the "src/main/java" directory.
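Assuming the standard Gradle layout (an assumption, not stated in the gist), the project tree looks like:

project-root/
  build.gradle
  src/
    main/
      java/
        TransferFreshFruitTrainExperiment.java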

Resource dependency

The dataset is accessible from Kaggle (https://www.kaggle.com/datasets/sriramr/fruits-fresh-and-rotten-for-classification). You also need to set the "localPath" variable in getData() to point to the local data path.
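getData() builds the repository path as localPath + "/" + fruit + "_full/" + usage, so the local data folder is expected to look roughly like this (the class subfolders inside each split follow the Kaggle archive itself):

<localPath>/
  apple_full/
    train/
    test/
  banana_full/
    train/
    test/
  orange_full/
    train/
    test/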

See the Medium article: https://pub.towardsai.net/blazing-fast-training-with-small-dataset-for-java-applications-4acb9332cd0b

build.gradle:

plugins {
    id 'java'
}

repositories {
    mavenCentral()
}

dependencies {
    implementation "org.apache.logging.log4j:log4j-slf4j-impl:2.17.1"
    implementation platform("ai.djl:bom:0.20.0")
    implementation "ai.djl:api"
    runtimeOnly "ai.djl.pytorch:pytorch-engine"
    runtimeOnly "ai.djl.pytorch:pytorch-model-zoo"
}

test {
    useJUnitPlatform()
}
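To launch the experiment directly from Gradle, one option (not part of the original gist) is to add the application plugin; the mainClass value below assumes the class sits in the default package, as in this gist:

plugins {
    id 'java'
    id 'application'
}

application {
    mainClass = 'TransferFreshFruitTrainExperiment'
}

After that, ./gradlew run starts the training loops and writes the result arrays under build/.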
TransferFreshFruitTrainExperiment.java:

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.cv.classification.FruitsFreshAndRotten;
import ai.djl.engine.Engine;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.OneHot;
import ai.djl.modality.cv.transform.RandomFlipLeftRight;
import ai.djl.modality.cv.transform.RandomFlipTopBottom;
import ai.djl.modality.cv.transform.RandomResizedCrop;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public final class TransferFreshFruitTrainExperiment {

    private TransferFreshFruitTrainExperiment() {}

    public static void main(String[] args)
            throws IOException, TranslateException, ModelException, URISyntaxException {
        TransferFreshFruitTrainExperiment.runExample(args);
    }

    public static TrainingResult runExample(String[] args)
            throws IOException, TranslateException, ModelException, URISyntaxException {
        String[] fruits = {"apple", "banana", "orange"};
        // "trainParam" controls whether the pre-trained embedding parameters are also updated
        boolean[] trainParams = {true, false};
        TrainingResult result = null;
        for (String fruit : fruits) {
            for (boolean trainParam : trainParams) {
                Criteria<NDList, NDList> criteria =
                        Criteria.builder()
                                .setTypes(NDList.class, NDList.class)
                                .optModelUrls("djl://ai.djl.pytorch/resnet18_embedding")
                                .optEngine("PyTorch")
                                .optProgress(new ProgressBar())
                                .optOption("trainParam", String.valueOf(trainParam))
                                .build();
                ZooModel<NDList, NDList> embedding = criteria.loadModel();
                Block baseBlock = embedding.getBlock();
                Block blocks =
                        new SequentialBlock()
                                .add(baseBlock)
                                .addSingleton(nd -> nd.squeeze(new int[] {2, 3})) // drop the 1x1 spatial dims
                                .add(Linear.builder().setUnits(2).build()) // 2 classes: fresh vs. rotten
                                .addSingleton(nd -> nd.softmax(1));
                Model model = Model.newInstance("TransferFreshFruit");
                model.setBlock(blocks);

                try (NDManager manager = NDManager.newBaseManager()) {
                    // Training-set sizes to test: 2, 10, 20, ..., 100
                    NDArray cuts = manager.arange(10, 110, 10);
                    NDArray cut2 = manager.create(new int[] {2});
                    cuts = cut2.concat(cuts);
                    float[] accuracy = new float[(int) cuts.size()];
                    int cnt = 0;
                    for (long cut : cuts.toIntArray()) {
                        DefaultTrainingConfig config = setupTrainingConfig(baseBlock);
                        Trainer trainer = model.newTrainer(config);
                        trainer.setMetrics(new Metrics());

                        int batchSize = 32;
                        trainer.initialize(new Shape(batchSize, 3, 224, 224));

                        RandomAccessDataset datasetTrain = getData(fruit, "train", batchSize, cut);
                        RandomAccessDataset datasetTest = getData(fruit, "test", batchSize, 0); // 0: full test set

                        // 49 epochs without validation, then one final epoch with validation
                        EasyTrain.fit(trainer, 49, datasetTrain, null);
                        EasyTrain.fit(trainer, 1, datasetTrain, datasetTest);
                        result = trainer.getTrainingResult();
                        accuracy[cnt++] = result.getValidateEvaluation("Accuracy");
                    }
                    cuts.setName("cuts_" + fruit + "_" + trainParam);
                    NDArray acc = manager.create(accuracy, new Shape(accuracy.length));
                    acc.setName("accuracy_" + fruit + "_" + trainParam);
                    saveNDArray(cuts);
                    saveNDArray(acc);
                }
                model.close();
                embedding.close();
            }
        }
        return result;
    }

    private static void saveNDArray(NDArray array) throws IOException {
        Path path = Paths.get("build").resolve(array.getName() + ".npz");
        try (OutputStream os = Files.newOutputStream(path)) {
            new NDList(array).encode(os, true); // write in NumPy .npz format
        }
    }

    private static RandomAccessDataset getData(String fruit, String usage, int batchSize, long cut)
            throws TranslateException, IOException {
        // The dataset is accessible from:
        // https://www.kaggle.com/datasets/sriramr/fruits-fresh-and-rotten-for-classification
        float[] mean = {0.485f, 0.456f, 0.406f};
        float[] std = {0.229f, 0.224f, 0.225f};

        // Set localPath to the local folder that contains the downloaded Kaggle data
        String localPath = "localPath";
        Repository repository =
                Repository.newInstance(
                        fruit, Paths.get(localPath + "/" + fruit + "_full/" + usage));
        FruitsFreshAndRotten dataset =
                FruitsFreshAndRotten.builder()
                        .optRepository(repository)
                        .addTransform(new RandomResizedCrop(256, 256)) // only in training
                        .addTransform(new RandomFlipTopBottom()) // only in training
                        .addTransform(new RandomFlipLeftRight()) // only in training
                        .addTransform(new Resize(256, 256))
                        .addTransform(new CenterCrop(224, 224))
                        .addTransform(new ToTensor())
                        .addTransform(new Normalize(mean, std))
                        .addTargetTransform(new OneHot(2))
                        .setSampling(batchSize, true)
                        .build();
        dataset.prepare();

        // If cut > 0, randomly subsample `cut` examples from the dataset
        if (cut > 0) {
            List<Long> batchIndexList = new ArrayList<>();
            try (NDManager manager = NDManager.newBaseManager()) {
                NDArray indices = manager.randomPermutation(dataset.size());
                NDArray batchIndex = indices.get(":{}", cut);
                for (long index : batchIndex.toLongArray()) {
                    batchIndexList.add(index);
                }
            }
            return dataset.subDataset(batchIndexList);
        }
        return dataset;
    }

    private static DefaultTrainingConfig setupTrainingConfig(Block baseBlock) {
        String outputDir = "build/fruits";
        SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
        listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model = trainer.getModel();
                    float accuracy = result.getValidateEvaluation("Accuracy");
                    model.setProperty("Accuracy", String.format("%.5f", accuracy));
                    model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });
        DefaultTrainingConfig config =
                new DefaultTrainingConfig(new SoftmaxCrossEntropy("SoftmaxCrossEntropy"))
                        .addEvaluator(new Accuracy())
                        .optDevices(Engine.getInstance().getDevices(1))
                        .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
                        .addTrainingListeners(listener);

        // Customized learning rate: the pre-trained base parameters use a 10x smaller rate than the new head
        float lr = 0.0001f;
        FixedPerVarTracker.Builder learningRateTrackerBuilder =
                FixedPerVarTracker.builder().setDefaultValue(lr);
        for (Pair<String, Parameter> paramPair : baseBlock.getParameters()) {
            learningRateTrackerBuilder.put(paramPair.getValue().getId(), 0.1f * lr);
        }
        FixedPerVarTracker learningRateTracker = learningRateTrackerBuilder.build();
        Optimizer optimizer = Adam.builder().optLearningRateTracker(learningRateTracker).build();
        config.optOptimizer(optimizer);
        return config;
    }

    private static class SoftmaxCrossEntropy extends Loss {

        /**
         * Constructs a softmax cross-entropy loss that expects one-hot labels and probability
         * (softmax) predictions.
         *
         * @param name the display name of the loss
         */
        public SoftmaxCrossEntropy(String name) {
            super(name);
        }

        /** {@inheritDoc} */
        @Override
        public NDArray evaluate(NDList labels, NDList predictions) {
            // Here the labels are expected to be one-hot encoded
            int classAxis = -1;
            NDArray pred = predictions.singletonOrThrow().log();
            NDArray lab = labels.singletonOrThrow().reshape(pred.getShape());
            NDArray loss = pred.mul(lab).neg().sum(new int[] {classAxis}, true);
            return loss.mean();
        }
    }
}
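
After a run, the arrays written by saveNDArray can be read back, for example to plot validation accuracy against the dataset-size cuts. A minimal sketch (the class name ReadResults is illustrative, and it assumes NDList.decode accepts the NumPy-format .npz files produced above):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;

public final class ReadResults {
    public static void main(String[] args) throws Exception {
        // File names follow saveNDArray(): build/<arrayName>.npz, e.g. accuracy_apple_true.npz
        String file = "build/accuracy_apple_true.npz";
        try (NDManager manager = NDManager.newBaseManager();
                InputStream is = Files.newInputStream(Paths.get(file))) {
            NDList list = NDList.decode(manager, is);
            NDArray accuracy = list.singletonOrThrow();
            System.out.println(accuracy); // one validation accuracy per training-set size
        }
    }
}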