Created
May 7, 2019 12:29
-
-
Save tzolov/60d3b0a06897ec33069eadf6b6def8b8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.mindmodel.services.image.recognition; | |
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.BiFunction;

import com.google.protobuf.ByteString;
import io.mindmodel.services.common.AutoCloseableSession;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.example.BytesList;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Empty;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.io.ParseExample;
/** | |
* @author Christian Tzolov | |
*/ | |
public class Test1 { | |
public static class ParseExampleGraph extends AutoCloseableSession implements BiFunction<String, Example, List<Tensor<?>>> { | |
private ParseExample parser; | |
private static Empty<String> emptyString(Ops tf) { | |
return tf.empty(tf.constant(new int[] { 1 }), String.class); | |
} | |
@Override | |
protected void doGraphDefinition(Ops tf) { | |
Placeholder<String> examples = tf.withName("examples").placeholder(String.class, Placeholder.shape(Shape.make(1))); | |
Placeholder<String> names = tf.withName("names").placeholder(String.class, Placeholder.shape(Shape.make(1))); | |
this.parser = tf.io.parseExample( | |
examples, | |
names, | |
Collections.emptyList(), | |
Arrays.asList(tf.constant(FEATURE_A), tf.constant(FEATURE_B)), | |
Arrays.asList(emptyString(tf), emptyString(tf)), | |
Collections.emptyList(), | |
Arrays.asList(Shape.make(1), Shape.make(1)) | |
); | |
} | |
@Override | |
public List<Tensor<?>> apply(String name, Example example) { | |
ByteBuffer exampleData = Buffers.stringToBuffer(example); | |
ByteBuffer exampleName = Buffers.stringToBuffer(name); | |
try (Tensor<String> exampleTensor = Tensor.create(String.class, new long[] { 1 }, exampleData); | |
Tensor<String> nameTensor = Tensor.create(String.class, new long[] { 1 }, exampleName)) { | |
List<Tensor<?>> featureValues = this.getSession().runner() | |
.feed("examples", exampleTensor) | |
.feed("names", nameTensor) | |
.fetch(this.parser.denseValues().get(0)) | |
.fetch(this.parser.denseValues().get(1)) | |
.run(); | |
return featureValues; | |
} | |
} | |
} | |
public static void main(String[] args) throws UnsupportedEncodingException { | |
List<Tensor<?>> featureValues = new ParseExampleGraph().apply("Example", buildExample()); | |
for (Tensor<?> featureValue : featureValues) { | |
// Tricky way to read back a string from a tensor, | |
// each feature is a single list (dim 0) of one string (dim 1) of a variable length (dim 2) | |
byte[][][] value = new byte[1][1][]; | |
featureValue.copyTo(value); | |
System.out.println(new String(value[0][0], "UTF-8")); | |
featureValue.close(); | |
} | |
} | |
// Building test fixture | |
private static final String FEATURE_A = "featureA"; | |
private static final String FEATURE_B = "featureB"; | |
private static Example buildExample() { | |
Feature featureA = Feature.newBuilder() | |
.setBytesList(BytesList.newBuilder() | |
.addValue(ByteString.copyFrom("A feature", Charset.forName("UTF-8"))).build()) | |
.build(); | |
Feature featureB = Feature.newBuilder() | |
.setBytesList(BytesList.newBuilder() | |
.addValue(ByteString.copyFrom("Another feature", Charset.forName("UTF-8"))).build()) | |
.build(); | |
Features features = Features.newBuilder() | |
.putFeature(FEATURE_A, featureA) | |
.putFeature(FEATURE_B, featureB) | |
.build(); | |
return Example.newBuilder().setFeatures(features).build(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment