Skip to content

Instantly share code, notes, and snippets.

@tanapoln
Created June 12, 2018 10:31
Show Gist options
  • Save tanapoln/2b800d8adf14584e3f4587cb1f88491d to your computer and use it in GitHub Desktop.
Save tanapoln/2b800d8adf14584e3f4587cb1f88491d to your computer and use it in GitHub Desktop.
TensorFlow Java API with Sertis Thai Word Segmentation
import java.util.AbstractList;
import java.util.List;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Read-only {@link List} view over {@link AutoCloseable} elements that can itself be closed.
 * <p>
 * Closing the list closes every element in iteration order. A failure to close one element is
 * logged at WARN level and does not prevent the remaining elements from being closed.
 *
 * @param <E> element type
 */
public class CloseableList<E extends AutoCloseable> extends AbstractList<E> implements AutoCloseable {
    public static final Logger LOGGER = LoggerFactory.getLogger(CloseableList.class);

    /** Backing list; it is wrapped, not copied, so later changes to it are visible here. */
    private final List<E> closableList;

    /**
     * Wrap the given list of closeable elements.
     *
     * @param closableList
     *            elements to expose and eventually close; must not be null
     * @throws NullPointerException
     *             if {@code closableList} is null
     */
    public CloseableList(List<E> closableList) {
        this.closableList = Objects.requireNonNull(closableList);
    }

    @Override
    public E get(int index) {
        return closableList.get(index);
    }

    @Override
    public int size() {
        return closableList.size();
    }

    /**
     * Close every element, logging (rather than propagating) any exception raised while doing so.
     */
    @Override
    public void close() {
        for (E element : closableList) {
            closeQuietly(element);
        }
    }

    /** Close a single element; any exception (including an NPE for a null element) is only logged. */
    private static void closeQuietly(AutoCloseable element) {
        try {
            element.close();
        } catch (Exception e) {
            LOGGER.warn("Cannot close item in list", e);
        }
    }
}
import java.util.HashMap;
import java.util.Map;
public class SertisCharacterNumberMapping {
    /**
     * Character vocabulary of the Sertis word-segmentation model; a character's label is its index
     * in this array.
     * <p>
     * https://github.com/sertiscorp/thai-word-segmentation/blob/master/thainlplib/labeller.py
     */
    private static final char[] CHARACTERS_LABEL = new char[] {'\0', '\n', ' ', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*',
        '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A',
        'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
        'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
        'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', 'ก', 'ข', 'ฃ', 'ค', 'ฅ', 'ฆ', 'ง', 'จ',
        'ฉ', 'ช', 'ซ', 'ฌ', 'ญ', 'ฎ', 'ฏ', 'ฐ', 'ฑ', 'ฒ', 'ณ', 'ด', 'ต', 'ถ', 'ท', 'ธ', 'น', 'บ', 'ป', 'ผ', 'ฝ', 'พ', 'ฟ',
        'ภ', 'ม', 'ย', 'ร', 'ฤ', 'ล', 'ฦ', 'ว', 'ศ', 'ษ', 'ส', 'ห', 'ฬ', 'อ', 'ฮ', 'ฯ', 'ะ', 'ั', 'า', 'ำ', 'ิ', 'ี', 'ึ',
        'ื', 'ุ', 'ู', '฿', 'เ', 'แ', 'โ', 'ใ', 'ไ', 'ๅ', 'ๆ', '็', '่', '้', '๊', '๋', '์', '๐', '๑', '๒', '๓', '๔', '๕',
        '๖', '๗', '๘', '๙'};

    /** Reverse lookup from character to its label (index in {@link #CHARACTERS_LABEL}). */
    private static final Map<Character, Long> INVERTED_CHARACTERS_LABEL = buildInvertedIndex();

    /** Build the character-to-label map; a later duplicate, if any, would win. */
    private static Map<Character, Long> buildInvertedIndex() {
        Map<Character, Long> index = new HashMap<>();
        long label = 0L;
        for (char symbol : CHARACTERS_LABEL) {
            index.put(symbol, label);
            label++;
        }
        return index;
    }

    /**
     * Look up the model label of a single character.
     *
     * @param c
     *            character to encode
     * @return the character's label, or 0 (the label of {@code '\0'}) when it is not in the vocabulary
     */
    public static long getCharacterLabel(char c) {
        return INVERTED_CHARACTERS_LABEL.getOrDefault(c, 0L);
    }

    /**
     * Encode every character of the input with {@link #getCharacterLabel(char)}.
     *
     * @param chars
     *            characters to encode; must not be null
     * @return a new array of the same length holding one label per input character
     */
    public static long[] mapCharsToNumbers(char[] chars) {
        long[] labels = new long[chars.length];
        int position = 0;
        for (char c : chars) {
            labels[position++] = getCharacterLabel(c);
        }
        return labels;
    }
}
import java.util.Map;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import com.google.protobuf.InvalidProtocolBufferException;
/**
 * Wrapper around a TensorFlow {@link SavedModelBundle} that resolves real tensor names through the
 * model's {@value #TENSORFLOW_DEFAULT_SERVING_SIGNATURE_DEF_KEY} signature.
 */
public class TensorflowModel {
    /** Signature key looked up by {@link #getDefaultServingSignatureDef()}. */
    public static final String TENSORFLOW_DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default";

    private final SavedModelBundle modelBundle;
    /** The bundle's session, captured once at construction. */
    private final Session session;
    /** Lazily parsed serving signature; stays null until one is found. */
    private SignatureDef signatureDef;

    /**
     * Create TensorFlow model from saved model bundle.
     *
     * @param modelBundle
     *            loaded model bundle; its session is captured immediately
     */
    public TensorflowModel(SavedModelBundle modelBundle) {
        this.modelBundle = modelBundle;
        this.session = modelBundle.session();
    }

    /**
     * Load model from absolute file system path with default "serve" tag.
     *
     * @param fileSystemPath
     *            an absolute file system path to a SavedModel directory
     * @return tensor flow model object
     */
    public static TensorflowModel loadFromPath(String fileSystemPath) {
        SavedModelBundle savedModelBundle = SavedModelBundle.load(fileSystemPath, "serve");
        return new TensorflowModel(savedModelBundle);
    }

    /**
     * SignatureDef allows you to fetch generated tensor names through a predefined mapping.
     * <p>
     * You can see protobuf <a href=
     * "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/meta_graph.proto">meta_graph.proto</a>.
     * The result is parsed once from the bundle's MetaGraphDef and then cached.
     *
     * @see #outputTensorName(String)
     * @see #inputTensorName(String)
     *
     * @return signature definition, or {@code null} if the model has no
     *         {@value #TENSORFLOW_DEFAULT_SERVING_SIGNATURE_DEF_KEY} signature
     * @throws IllegalStateException
     *             if the bundle's MetaGraphDef cannot be parsed
     */
    public SignatureDef getDefaultServingSignatureDef() {
        if (signatureDef == null) {
            try {
                MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(modelBundle.metaGraphDef());
                Map<String, SignatureDef> signatureDefMap = metaGraphDef.getSignatureDefMap();
                signatureDef = signatureDefMap.get(TENSORFLOW_DEFAULT_SERVING_SIGNATURE_DEF_KEY);
            } catch (InvalidProtocolBufferException e) {
                throw new IllegalStateException("Cannot parse tensorflow model for MetaGraphDef", e);
            }
        }
        return signatureDef;
    }

    /**
     * Get real output tensor name from predefined mapping name.
     *
     * @param name
     *            predefined output mapping name
     * @return real tensor name
     * @throws IllegalStateException
     *             if the model has no default serving signature
     * @throws IllegalArgumentException
     *             if the signature has no output called {@code name}
     */
    public String outputTensorName(String name) {
        SignatureDef def = requireDefaultServingSignatureDef();
        if (!def.getOutputsMap().containsKey(name)) {
            throw new IllegalArgumentException("No output named '" + name + "' in signature '"
                    + TENSORFLOW_DEFAULT_SERVING_SIGNATURE_DEF_KEY + "', available: " + def.getOutputsMap().keySet());
        }
        return def.getOutputsMap().get(name).getName();
    }

    /**
     * Get real input tensor name from predefined mapping name.
     *
     * @param name
     *            predefined input mapping name
     * @return real tensor name
     * @throws IllegalStateException
     *             if the model has no default serving signature
     * @throws IllegalArgumentException
     *             if the signature has no input called {@code name}
     */
    public String inputTensorName(String name) {
        SignatureDef def = requireDefaultServingSignatureDef();
        if (!def.getInputsMap().containsKey(name)) {
            throw new IllegalArgumentException("No input named '" + name + "' in signature '"
                    + TENSORFLOW_DEFAULT_SERVING_SIGNATURE_DEF_KEY + "', available: " + def.getInputsMap().keySet());
        }
        return def.getInputsMap().get(name).getName();
    }

    /**
     * Return the session owned by the underlying {@link SavedModelBundle}.
     * <p>
     * NOTE: despite its name, this does NOT create a new session. {@code SavedModelBundle#session()}
     * returns the bundle's own session, which is the same instance returned by
     * {@link #getSingletonSession()}. Do not close it while this model is still in use, or the
     * singleton session will be closed too. The name is kept for backward compatibility.
     *
     * @return the bundle's tensorflow session
     */
    public Session createNewSession() {
        return modelBundle.session();
    }

    /**
     * Get the session captured when this model was constructed; always the same instance.
     *
     * @return tensorflow session
     */
    public Session getSingletonSession() {
        return session;
    }

    /** Return the default serving signature, failing with a descriptive message if it is absent. */
    private SignatureDef requireDefaultServingSignatureDef() {
        SignatureDef def = getDefaultServingSignatureDef();
        if (def == null) {
            throw new IllegalStateException(
                    "Model has no '" + TENSORFLOW_DEFAULT_SERVING_SIGNATURE_DEF_KEY + "' signature");
        }
        return def;
    }
}
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;
import org.tensorflow.Tensor;
/**
 * Word tokenizer that runs a TensorFlow SavedModel over characters encoded with
 * {@code SertisCharacterNumberMapping}.
 * <p>
 * The model is loaded once, when this class is initialised, from the directory named by the
 * {@value #MODEL_PATH_PROPERTY} system property, or from the original hard-coded location when the
 * property is not set.
 */
public class Tokenizer {
    /** System property that overrides the SavedModel directory. */
    public static final String MODEL_PATH_PROPERTY = "tokenizer.model.path";

    /** Location used when {@link #MODEL_PATH_PROPERTY} is not set (the original hard-coded path). */
    private static final String DEFAULT_MODEL_PATH = "/Users/wongnai/tensorflow/model";

    private static final TensorflowModel TENSORFLOW_MODEL =
            TensorflowModel.loadFromPath(System.getProperty(MODEL_PATH_PROPERTY, DEFAULT_MODEL_PATH));

    /**
     * Util class.
     */
    private Tokenizer() {
    }

    /**
     * Tokenize.
     * <p>
     * Whitespace is first normalised with {@code StringUtils.normalizeSpace}; chunks that are blank
     * after splitting are dropped.
     *
     * @param input
     *            text to tokenize; must not be null
     * @return tokens in input order, or an empty array for blank input
     * @throws NullPointerException
     *             if {@code input} is null
     */
    public static String[] tokenize(String input) {
        Objects.requireNonNull(input, "input must not be null");
        if (StringUtils.isBlank(input)) {
            return new String[0];
        }
        char[] chars = StringUtils.normalizeSpace(input).toCharArray();
        long[] beginningOfWordFlag = calculateBeginOfWordFlag(chars);
        return breakTokens(chars, beginningOfWordFlag);
    }

    /**
     * Run the model over the characters and return one flag per character.
     * <p>
     * Assumes the model's first output holds exactly {@code chars.length} int64 values (a non-zero
     * value meaning "a word starts here"); {@code Tensor#copyTo} rejects any other shape.
     */
    private static long[] calculateBeginOfWordFlag(char[] chars) {
        long[] input = SertisCharacterNumberMapping.mapCharsToNumbers(chars);
        // Resources are closed in reverse declaration order, exactly like the previous nested tries.
        try (Tensor<Long> inputTensor = Tensor.create(new long[] {1, input.length}, LongBuffer.wrap(input));
                Tensor<Long> lengthTensor = Tensor.create(new long[] {1}, LongBuffer.wrap(new long[] {input.length}));
                Tensor<Boolean> trainingTensor = Tensor.create(false, Boolean.class);
                CloseableList<Tensor<?>> results = runModel(inputTensor, lengthTensor, trainingTensor)) {
            if (results.isEmpty()) {
                throw new IllegalStateException("Cannot tokenize string with tensorflow, got invalid result");
            }
            long[] beginningOfWordFlag = new long[input.length];
            results.get(0).copyTo(beginningOfWordFlag);
            return beginningOfWordFlag;
        }
    }

    /** Feed the three inputs through the serving signature and fetch its "outputs" tensor. */
    private static CloseableList<Tensor<?>> runModel(Tensor<Long> inputTensor, Tensor<Long> lengthTensor,
            Tensor<Boolean> trainingTensor) {
        List<Tensor<?>> results = TENSORFLOW_MODEL.getSingletonSession().runner()
                .feed(TENSORFLOW_MODEL.inputTensorName("inputs"), inputTensor)
                .feed(TENSORFLOW_MODEL.inputTensorName("lengths"), lengthTensor)
                .feed(TENSORFLOW_MODEL.inputTensorName("training"), trainingTensor)
                .fetch(TENSORFLOW_MODEL.outputTensorName("outputs"))
                .run();
        return new CloseableList<>(results);
    }

    /** Split the characters at flagged positions and keep the non-blank pieces. */
    private static String[] breakTokens(char[] chars, long[] beginningOfWordFlag) {
        char[][] chunks = TokenizeUtils.chunkByFlag(chars, beginningOfWordFlag);
        List<String> result = new ArrayList<>(chunks.length);
        for (char[] chunk : chunks) {
            String str = String.valueOf(chunk);
            if (StringUtils.isNotBlank(str)) {
                result.add(str);
            }
        }
        return result.toArray(new String[0]);
    }
}
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import com.google.common.base.Preconditions;
/**
 * Helpers for splitting character arrays into tokens.
 */
public class TokenizeUtils {
    private TokenizeUtils() {
    }

    /**
     * Chunk char array by flags.
     * <p>
     * Every index {@code i} with {@code flags[i] != 0} starts a new chunk. Characters before the first
     * flagged index form a chunk of their own. When no index is flagged, the whole array is returned as
     * a single chunk. Empty chunks are never returned, so empty input yields an empty result.
     *
     * @param chars
     *            characters to split; must not be null
     * @param flags
     *            per-character break flags, same length as {@code chars}; must not be null
     * @return the non-empty chunks, in input order
     * @throws IllegalArgumentException
     *             if {@code chars} and {@code flags} differ in length
     */
    public static char[][] chunkByFlag(char[] chars, long[] flags) {
        if (chars.length != flags.length) {
            throw new IllegalArgumentException("flags length must be equals to chars length");
        }
        List<char[]> chunks = new ArrayList<>();
        int chunkStart = 0;
        for (int i = 0; i < flags.length; i++) {
            if (flags[i] != 0L) {
                // Close the chunk that ends just before this flagged index.
                addIfNotEmpty(chunks, chars, chunkStart, i);
                chunkStart = i;
            }
        }
        // Tail: from the last flagged index (or 0 when nothing was flagged) to the end.
        addIfNotEmpty(chunks, chars, chunkStart, chars.length);
        return chunks.toArray(new char[0][]);
    }

    /** Append a copy of {@code chars[from, to)} to {@code chunks} unless the range is empty. */
    private static void addIfNotEmpty(List<char[]> chunks, char[] chars, int from, int to) {
        if (to > from) {
            chunks.add(Arrays.copyOfRange(chars, from, to));
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment