Skip to content

Instantly share code, notes, and snippets.

@tzolov
Created May 7, 2019 12:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tzolov/60d3b0a06897ec33069eadf6b6def8b8 to your computer and use it in GitHub Desktop.
Save tzolov/60d3b0a06897ec33069eadf6b6def8b8 to your computer and use it in GitHub Desktop.
package io.mindmodel.services.image.recognition;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.BiFunction;
import com.google.protobuf.ByteString;
import io.mindmodel.services.common.AutoCloseableSession;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.example.BytesList;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Empty;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.io.ParseExample;
/**
 * Small runnable sample: builds an {@link Example} proto holding two bytes features, pushes it
 * through a TensorFlow {@code ParseExample} op and prints the two dense feature values that
 * come back.
 *
 * @author Christian Tzolov
 */
public class Test1 {

    // Building test fixture
    private static final String FEATURE_A = "featureA";
    private static final String FEATURE_B = "featureB";

    /** Single charset instance used both for building the fixture and for decoding the results. */
    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /**
     * Session wrapper whose graph consists of two string placeholders ("examples" and "names")
     * wired into a {@code ParseExample} op that extracts {@link #FEATURE_A} and
     * {@link #FEATURE_B} as dense features.
     */
    public static class ParseExampleGraph extends AutoCloseableSession implements BiFunction<String, Example, List<Tensor<?>>> {

        /** The parse op created in {@link #doGraphDefinition(Ops)}; its dense outputs are fetched in {@link #apply}. */
        private ParseExample parser;

        /**
         * Creates an {@code Empty} op of string type whose shape operand is the constant vector
         * {@code {1}}. It is passed as the dense default for each parsed feature.
         *
         * @param tf the ops builder to add the op to
         * @return the newly added {@code Empty} op
         */
        private static Empty<String> emptyString(Ops tf) {
            return tf.empty(tf.constant(new int[] { 1 }), String.class);
        }

        /**
         * Adds the placeholders and the {@code ParseExample} op to the graph.
         * NOTE(review): presumably invoked by {@code AutoCloseableSession} while it builds the
         * graph - confirm against the base class.
         *
         * @param tf the ops builder for the graph being defined
         */
        @Override
        protected void doGraphDefinition(Ops tf) {
            // Both inputs are declared as string vectors of length one (a batch of one example).
            Placeholder<String> examples = tf.withName("examples").placeholder(String.class, Placeholder.shape(Shape.make(1)));
            Placeholder<String> names = tf.withName("names").placeholder(String.class, Placeholder.shape(Shape.make(1)));

            this.parser = tf.io.parseExample(
                    examples,
                    names,
                    // no sparse feature keys
                    Collections.emptyList(),
                    // dense feature keys
                    Arrays.asList(tf.constant(FEATURE_A), tf.constant(FEATURE_B)),
                    // dense defaults, one per dense key
                    Arrays.asList(emptyString(tf), emptyString(tf)),
                    // no sparse types
                    Collections.emptyList(),
                    // dense shapes, one per dense key
                    Arrays.asList(Shape.make(1), Shape.make(1))
            );
        }

        /**
         * Feeds the example and its name to the graph and fetches the two dense feature values.
         *
         * @param name    value fed to the "names" placeholder
         * @param example proto fed to the "examples" placeholder
         * @return the fetched tensors, first for {@link #FEATURE_A} then for {@link #FEATURE_B};
         *         the caller owns them and must close each one
         */
        @Override
        public List<Tensor<?>> apply(String name, Example example) {
            // NOTE(review): Buffers is neither imported nor defined in this file, so it is
            // presumably a same-package helper with overloads for Example and String - confirm
            // that the Example overload yields the serialized proto in TensorFlow's string
            // tensor encoding.
            ByteBuffer exampleData = Buffers.stringToBuffer(example);
            ByteBuffer exampleName = Buffers.stringToBuffer(name);

            // The input tensors are only needed for the duration of the run, so they are
            // released by try-with-resources as soon as the results are available.
            try (Tensor<String> exampleTensor = Tensor.create(String.class, new long[] { 1 }, exampleData);
                    Tensor<String> nameTensor = Tensor.create(String.class, new long[] { 1 }, exampleName)) {
                return this.getSession().runner()
                        .feed("examples", exampleTensor)
                        .feed("names", nameTensor)
                        .fetch(this.parser.denseValues().get(0))
                        .fetch(this.parser.denseValues().get(1))
                        .run();
            }
        }
    }

    /**
     * Parses the test fixture and prints each returned feature value on its own line.
     *
     * @param args unused
     * @throws UnsupportedEncodingException no longer thrown now that decoding uses a
     *         {@link Charset}; kept in the signature so existing callers still compile
     */
    public static void main(String[] args) throws UnsupportedEncodingException {
        // NOTE(review): the ParseExampleGraph instance is never closed. Its base class name
        // suggests it is AutoCloseable - if so, wrap it in try-with-resources (confirm whether
        // close() declares a checked exception first).
        List<Tensor<?>> featureValues = new ParseExampleGraph().apply("Example", buildExample());
        try {
            for (Tensor<?> featureValue : featureValues) {
                // Tricky way to read back a string from a tensor,
                // each feature is a single list (dim 0) of one string (dim 1) of a variable length (dim 2)
                byte[][][] value = new byte[1][1][];
                featureValue.copyTo(value);
                System.out.println(new String(value[0][0], UTF_8));
            }
        }
        finally {
            // Release every fetched tensor, including the ones that were never reached
            // because an earlier copy or print failed.
            for (Tensor<?> featureValue : featureValues) {
                featureValue.close();
            }
        }
    }

    /**
     * Builds the fixture: an {@link Example} with two bytes features, "A feature" under
     * {@link #FEATURE_A} and "Another feature" under {@link #FEATURE_B}.
     *
     * @return the example proto
     */
    private static Example buildExample() {
        Features features = Features.newBuilder()
                .putFeature(FEATURE_A, bytesFeature("A feature"))
                .putFeature(FEATURE_B, bytesFeature("Another feature"))
                .build();
        return Example.newBuilder().setFeatures(features).build();
    }

    /**
     * Wraps a single UTF-8 encoded string into a bytes-list {@link Feature}.
     *
     * @param text the string to store as the only value of the feature
     * @return the feature proto
     */
    private static Feature bytesFeature(String text) {
        return Feature.newBuilder()
                .setBytesList(BytesList.newBuilder()
                        .addValue(ByteString.copyFrom(text, UTF_8)).build())
                .build();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment