Skip to content

Instantly share code, notes, and snippets.

@zaleslaw
zaleslaw / dataLoading.kt
Created June 9, 2022 13:13
Loads data into the training and validation datasets
// Build one dataset per split from the "train" and "val" children of datasetPath.
// Both splits share the same labelGenerator and preprocessing pipeline.
val trainDataset = File(datasetPath, "train").let { trainDir ->
    OnHeapDataset.create(trainDir, labelGenerator, preprocessing)
}
val valDataset = File(datasetPath, "val").let { valDir ->
    OnHeapDataset.create(valDir, labelGenerator, preprocessing)
}
@zaleslaw
zaleslaw / imagePreprocessingDSL.kt
Created June 9, 2022 13:12
Image Preprocessing DSL
val preprocessing = preprocess {
transformImage {
centerCrop {
size = 214
}
pad {
top = 10
bottom = 10
left = 10
right = 10
@zaleslaw
zaleslaw / callbacksUsage.kt
Created June 9, 2022 13:03
How to use multiple callbacks
model.use {
it.compile(
optimizer = Adam(clipGradient = ClipGradientByValue(0.1f)),
loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
metric = Metrics.ACCURACY
)
it.logSummary()
it.fit(
val earlyStopping = EarlyStopping(
monitor = EpochTrainingEvent::valLossValue,
minDelta = 0.0,
patience = 2,
verbose = true,
mode = EarlyStoppingMode.AUTO,
baseline = 0.1,
restoreBestWeights = false
)
val terminateOnNaN = TerminateOnNaN()
@zaleslaw
zaleslaw / trainingNoTopModel.kt
Created June 9, 2022 12:34
Trains the noTop model
model.use {
it.compile(
optimizer = Adam(),
loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
metric = Metrics.ACCURACY
)
it.loadWeightsForFrozenLayers(hdfFile)
it.fit(
@zaleslaw
zaleslaw / topModel.kt
Created June 9, 2022 12:33
Top model description
val topModel = Sequential.of(
GlobalAvgPool2D(
name = "top_avg_pool",
),
Dense(
name = "top_dense",
kernelInitializer = GlorotUniform(),
biasInitializer = GlorotUniform(),
outputSize = 200,
activation = Activations.Relu
@zaleslaw
zaleslaw / resnet50NoTop.kt
Created June 9, 2022 12:31
Loads the noTop model
// TensorFlow model hub whose cache directory is cache/pretrainedModels (relative path).
val modelHub = TFModelHub(
    cacheDirectory = File("cache/pretrainedModels"),
)
// ResNet50 descriptor with noTop = true (presumably without the top/classifier
// layers — confirm in TFModels docs) and an IMAGE_SIZE x IMAGE_SIZE x NUM_CHANNELS input.
val modelType = TFModels.CV.ResNet50(
    noTop = true,
    inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS),
)
// Resolve the descriptor through the hub to obtain the model instance.
val noTopModel = modelHub.loadModel(modelType)
@zaleslaw
zaleslaw / multiplePoseDetection.kt
Created June 9, 2022 09:50
Detects multiple poses
model.use { poseDetectionModel ->
val imageFile = …
val detectedPoses = poseDetectionModel.detectPoses(imageFile = imageFile, confidence = 0.0f)
detectedPoses.multiplePoses.forEach { detectedPose ->
println("Found ${detectedPose.first.classLabel} with probability ${detectedPose.first.probability}")
detectedPose.second.poseLandmarks.forEach {
println("Found ${it.poseLandmarkLabel} with probability ${it.probability}")
}
@zaleslaw
zaleslaw / loadMultiPoseModel.kt
Created June 9, 2022 09:49
Loads the MoveNetMultiPoseLighting model
// ONNX model hub whose cache directory is cache/pretrainedModels (relative path).
val modelHub = ONNXModelHub(
    cacheDirectory = File("cache/pretrainedModels"),
)
// Obtain the pretrained MoveNet multi-pose model through that hub.
val model = ONNXModels.PoseDetection.MoveNetMultiPoseLighting
    .pretrainedModel(modelHub)
@zaleslaw
zaleslaw / detectPose.kt
Created June 9, 2022 09:45
Pose detection
model.use { poseDetectionModel ->
val imageFile = …
val detectedPose = poseDetectionModel.detectPose(imageFile = imageFile)
detectedPose.poseLandmarks.forEach {
println("Found ${it.poseLandmarkLabel} with probability ${it.probability}")
}
detectedPose.edges.forEach {
println("The ${it.poseEdgeLabel} starts at ${it.start.poseLandmarkLabel} and ends with ${it.end.poseLandmarkLabel}")