Created
June 12, 2018 10:31
-
-
Save tanapoln/2b800d8adf14584e3f4587cb1f88491d to your computer and use it in GitHub Desktop.
Tensorflow Java API with Sertis Word Segmentation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.AbstractList; | |
import java.util.List; | |
import java.util.Objects; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
public class CloseableList<E extends AutoCloseable> extends AbstractList<E> implements AutoCloseable { | |
public static final Logger LOGGER = LoggerFactory.getLogger(CloseableList.class); | |
private final List<E> closableList; | |
/** | |
* Create list with closable objects. | |
* | |
* @param closableList | |
* list of tensor, must not be null | |
*/ | |
public CloseableList(List<E> closableList) { | |
Objects.requireNonNull(closableList); | |
this.closableList = closableList; | |
} | |
@Override | |
public E get(int index) { | |
return closableList.get(index); | |
} | |
@Override | |
public int size() { | |
return closableList.size(); | |
} | |
@Override | |
public void close() { | |
for (E item : closableList) { | |
try { | |
item.close(); | |
} catch (Exception e) { | |
LOGGER.warn("Cannot close item in list", e); | |
} | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.HashMap; | |
import java.util.Map; | |
/**
 * Maps characters to the integer labels expected as input by the Sertis Thai word segmentation
 * model.
 */
public class SertisCharacterNumberMapping {

    /**
     * Character vocabulary copied from the Sertis word segmentation labeller; a character's label is
     * its index in this array, so the order must not be changed.
     * <p>
     * https://github.com/sertiscorp/thai-word-segmentation/blob/master/thainlplib/labeller.py
     */
    private static final char[] CHARACTERS_LABEL = new char[] {'\0', '\n', ' ', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*',
        '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A',
        'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
        'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
        'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', 'ก', 'ข', 'ฃ', 'ค', 'ฅ', 'ฆ', 'ง', 'จ',
        'ฉ', 'ช', 'ซ', 'ฌ', 'ญ', 'ฎ', 'ฏ', 'ฐ', 'ฑ', 'ฒ', 'ณ', 'ด', 'ต', 'ถ', 'ท', 'ธ', 'น', 'บ', 'ป', 'ผ', 'ฝ', 'พ', 'ฟ',
        'ภ', 'ม', 'ย', 'ร', 'ฤ', 'ล', 'ฦ', 'ว', 'ศ', 'ษ', 'ส', 'ห', 'ฬ', 'อ', 'ฮ', 'ฯ', 'ะ', 'ั', 'า', 'ำ', 'ิ', 'ี', 'ึ',
        'ื', 'ุ', 'ู', '฿', 'เ', 'แ', 'โ', 'ใ', 'ไ', 'ๅ', 'ๆ', '็', '่', '้', '๊', '๋', '์', '๐', '๑', '๒', '๓', '๔', '๕',
        '๖', '๗', '๘', '๙'};

    /** Reverse lookup (character to label); built once after CHARACTERS_LABEL is initialized. */
    private static final Map<Character, Long> INVERTED_CHARACTERS_LABEL = buildInvertedLabels();

    /** Builds the character-to-index map; a later duplicate entry would win over an earlier one. */
    private static Map<Character, Long> buildInvertedLabels() {
        Map<Character, Long> labels = new HashMap<>();
        long label = 0L;
        for (char character : CHARACTERS_LABEL) {
            labels.put(character, label++);
        }
        return labels;
    }

    /**
     * Looks up the model label of one character.
     *
     * @param c
     *            character to look up
     * @return its index in the vocabulary, or 0 when the character is not in the vocabulary
     *         (note that 0 is also the label of {@code '\0'})
     */
    public static long getCharacterLabel(char c) {
        return INVERTED_CHARACTERS_LABEL.getOrDefault(c, 0L);
    }

    /**
     * Converts every character of the input into its model label.
     *
     * @param chars
     *            characters to convert, must not be null
     * @return a new array of the same length holding each character's label
     */
    public static long[] mapCharsToNumbers(char[] chars) {
        long[] labels = new long[chars.length];
        int position = 0;
        for (char character : chars) {
            labels[position++] = getCharacterLabel(character);
        }
        return labels;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Map; | |
import org.tensorflow.SavedModelBundle; | |
import org.tensorflow.Session; | |
import org.tensorflow.framework.MetaGraphDef; | |
import org.tensorflow.framework.SignatureDef; | |
import com.google.protobuf.InvalidProtocolBufferException; | |
public class TensorflowModel { | |
public static final String TENSORFLOW_DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default"; | |
private final SavedModelBundle modelBundle; | |
private final Session session; | |
private SignatureDef signatureDef; | |
/** | |
* Create TensorFlow model from saved model bundle. | |
* | |
* @param modelBundle | |
* model bundle | |
*/ | |
public TensorflowModel(SavedModelBundle modelBundle) { | |
this.modelBundle = modelBundle; | |
this.session = modelBundle.session(); | |
} | |
/** | |
* Load model from absolute file system path with default "serve" tag. | |
* | |
* @param fileSystemPath | |
* a absolute file system path | |
* @return tensor flow model object | |
*/ | |
public static TensorflowModel loadFromPath(String fileSystemPath) { | |
SavedModelBundle savedModelBundle = SavedModelBundle.load(fileSystemPath, "serve"); | |
return new TensorflowModel(savedModelBundle); | |
} | |
/** | |
* SignatureDef allow you to fetch generated tensor name with predefined | |
* mapping. | |
* <p> | |
* You can see protobuf <a href= | |
* "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/meta_graph.proto">meta_graph.proto</a>. | |
* | |
* @see #outputTensorName(String) | |
* @see #inputTensorName(String) | |
* | |
* @return signature definition | |
*/ | |
public SignatureDef getDefaultServingSignatureDef() { | |
try { | |
if (signatureDef == null) { | |
MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(modelBundle.metaGraphDef()); | |
Map<String, SignatureDef> signatureDefMap = metaGraphDef.getSignatureDefMap(); | |
signatureDef = signatureDefMap.get(TENSORFLOW_DEFAULT_SERVING_SIGNATURE_DEF_KEY); | |
} | |
return signatureDef; | |
} catch (InvalidProtocolBufferException e) { | |
throw new RuntimeException("Cannot parse tensorflow model for MetaGraphDef", e); | |
} | |
} | |
/** | |
* Get real output tensor name from predefined mapping name. | |
* | |
* @param name | |
* predefined output mapping name | |
* @return real tensor name | |
*/ | |
public String outputTensorName(String name) { | |
return getDefaultServingSignatureDef().getOutputsMap().get(name).getName(); | |
} | |
/** | |
* Get real input tensor name from predefined mapping name. | |
* | |
* @param name | |
* predefined input mapping name | |
* @return real tensor name | |
*/ | |
public String inputTensorName(String name) { | |
return getDefaultServingSignatureDef().getInputsMap().get(name).getName(); | |
} | |
/** | |
* Create new session, you need to close session after used manually. | |
* | |
* @return tensorflow session | |
*/ | |
public Session createNewSession() { | |
return modelBundle.session(); | |
} | |
/** | |
* Get session which guarantee to always be the same instance. | |
* | |
* @return tensorflow session | |
*/ | |
public Session getSingletonSession() { | |
return session; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.nio.LongBuffer; | |
import java.util.ArrayList; | |
import java.util.List; | |
import java.util.Objects; | |
import org.apache.commons.lang3.StringUtils; | |
import org.tensorflow.Tensor; | |
public class Tokenizer { | |
private static TensorflowModel tensorflowModel = TensorflowModel.loadFromPath("/Users/wongnai/tensorflow/model"); | |
/** | |
* Util class. | |
*/ | |
private Tokenizer() { | |
} | |
/** | |
* Tokenize. | |
* | |
* @param input | |
* input. | |
* @return tokenized string | |
*/ | |
public static String[] tokenize(String input) { | |
Objects.requireNonNull(input, "input must not be null"); | |
if (StringUtils.isBlank(input)) { | |
return new String[0]; | |
} | |
char[] chars = StringUtils.normalizeSpace(input).toCharArray(); | |
long[] beginningOfWordFlag = calculateBeginOfWordFlag(chars); | |
return breakTokens(chars, beginningOfWordFlag); | |
} | |
private static long[] calculateBeginOfWordFlag(char[] chars) { | |
long[] input = SertisCharacterNumberMapping.mapCharsToNumbers(chars); | |
try (Tensor<Long> inputTensor = Tensor.create(new long[] {1, input.length}, LongBuffer.wrap(input))) { | |
try (Tensor<Long> lengthTensor = Tensor.create(new long[] {1}, LongBuffer.wrap(new long[] {input.length}))) { | |
try (Tensor<Boolean> trainingTensor = Tensor.create(false, Boolean.class)) { | |
try (CloseableList<Tensor<?>> results = runModel(inputTensor, lengthTensor, trainingTensor)) { | |
if (results.isEmpty()) { | |
throw new IllegalStateException("Cannot tokenize string with tensorflow, got invalid result"); | |
} | |
long[] beginningOfWordFlag = new long[input.length]; | |
results.get(0).copyTo(beginningOfWordFlag); | |
return beginningOfWordFlag; | |
} | |
} | |
} | |
} | |
} | |
private static CloseableList<Tensor<?>> runModel(Tensor<Long> inputTensor, Tensor<Long> lengthTensor, | |
Tensor<Boolean> trainingTensor) { | |
List<Tensor<?>> results | |
= tensorflowModel.getSingletonSession().runner().feed(tensorflowModel.inputTensorName("inputs"), inputTensor) | |
.feed(tensorflowModel.inputTensorName("lengths"), lengthTensor) | |
.feed(tensorflowModel.inputTensorName("training"), trainingTensor) | |
.fetch(tensorflowModel.outputTensorName("outputs")).run(); | |
return new CloseableList<>(results); | |
} | |
private static String[] breakTokens(char[] chars, long[] beginningOfWordFlag) { | |
char[][] chunks = TokenizeUtils.chunkByFlag(chars, beginningOfWordFlag); | |
List<String> result = new ArrayList<>(); | |
for (int i = 0; i < chunks.length; i++) { | |
String str = String.valueOf(chunks[i]); | |
if (StringUtils.isNotBlank(str)) { | |
result.add(str); | |
} | |
} | |
return result.toArray(new String[0]); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.List; | |
import org.apache.commons.lang3.ArrayUtils; | |
import com.google.common.base.Preconditions; | |
/**
 * Helpers for turning per-character model output into tokens.
 */
public final class TokenizeUtils {

    private TokenizeUtils() {
    }

    /**
     * Chunk char array by flags.
     * <p>
     * A new chunk starts at every index whose flag is non-zero; index 0 always starts the first
     * chunk, so a flag there has no extra effect. Empty chunks are never returned. If no flag is
     * set, the whole (non-empty) array is returned as a single chunk; an empty input yields an
     * empty result.
     *
     * @param chars
     *            chars, must not be null
     * @param flags
     *            per-character boundary flags, same length as {@code chars}
     * @return 2d char array chunked; the chunks concatenated in order equal {@code chars}
     * @throws IllegalArgumentException if {@code flags} and {@code chars} differ in length
     */
    public static char[][] chunkByFlag(char[] chars, long[] flags) {
        if (chars.length != flags.length) {
            throw new IllegalArgumentException("flags length must be equals to chars length");
        }
        List<char[]> chunks = new ArrayList<>();
        int chunkStart = 0;
        // Start at 1: a flag at index 0 would only produce an empty leading chunk.
        for (int i = 1; i < chars.length; i++) {
            if (flags[i] != 0L) {
                chunks.add(Arrays.copyOfRange(chars, chunkStart, i));
                chunkStart = i;
            }
        }
        if (chunkStart < chars.length) {
            chunks.add(Arrays.copyOfRange(chars, chunkStart, chars.length));
        }
        return chunks.toArray(new char[0][]);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment