Converts a variable-length rank-1 tensor into a rank-2 skip-gram tensor. Intended to be used together with tf.data.Dataset.
import tensorflow as tf


class VarLenSeries:
    """Wraps a variable-length rank-1 tensor and expands it into skip-gram pairs.

    Example (TensorFlow 1.x):
        sess = tf.InteractiveSession()
        series = tf.constant(list(range(5)))
        sess.run(VarLenSeries(series).to_skip_gram(3))
        # returns:
        # array([[0, 1],
        #        [0, 2],
        #        [0, 3],
        #        [1, 2],
        #        [1, 3],
        #        [1, 4],
        #        [2, 3],
        #        [2, 4],
        #        [3, 4],
        #        [1, 0],
        #        [2, 1],
        #        [2, 0],
        #        [3, 2],
        #        [3, 1],
        #        [3, 0],
        #        [4, 3],
        #        [4, 2],
        #        [4, 1]], dtype=int32)
    """

    def __init__(self, series: tf.Tensor):
        self.series = series

    @staticmethod
    def _get_n_gram(series: tf.Tensor, window_size: int) -> tf.Tensor:
        """Slides a window of `window_size` over `series` and stacks the windows.

        Example (3-gram):
            [0, 1, 2, 3, 4, 5, 6] => [[0, 1, 2], [1, 2, 3], [2, 3, 4], ...]

        :param series: rank-1 tensor
        :param window_size: length of each window
        :return: rank-2 tensor of shape [n_windows, window_size]
        """
        w = window_size
        # The i-th shifted slice supplies the i-th column of the n-gram matrix;
        # the last slice cannot use a negative end index, so it is handled separately.
        return tf.transpose(
            tf.stack(
                [series[i:-(w - 1) + i] if w - 1 != i else series[i:] for i in range(w)]
            )
        )

    @classmethod
    def _skip_forward(cls, series: tf.Tensor, window_size: int) -> tf.Tensor:
        """Pairs each element with the elements up to `window_size` positions after it."""
        w = window_size + 1
        x = cls._get_n_gram(series, w)

        def _forward(_x: tf.Tensor) -> tf.Tensor:
            # (first element, each following element) within one window
            return tf.stack([tf.gather(_x, [0, i]) for i in range(1, w)])

        body = tf.reshape(tf.map_fn(_forward, x), [-1, 2])

        def _forward_tail(_x: tf.Tensor) -> tf.Tensor:
            # The last window still contains forward pairs that the body misses.
            tail_list = []
            last_row = _x[-1, :]
            for i in range(1, w - 1):
                for d in range(1, w - i):
                    tail_list.append(tf.gather(last_row, [i, i + d]))
            return tf.stack(tail_list)

        tail = _forward_tail(x)
        return tf.concat([body, tail], axis=0)

    @classmethod
    def _skip_backward(cls, series: tf.Tensor, window_size: int) -> tf.Tensor:
        """Pairs each element with the elements up to `window_size` positions before it."""
        w = window_size + 1
        x = cls._get_n_gram(series, w)

        def _backward(_x: tf.Tensor) -> tf.Tensor:
            # (last element, each preceding element) within one window
            return tf.stack([tf.gather(_x, [w - 1, w - 1 - i]) for i in range(1, w)])

        body = tf.reshape(tf.map_fn(_backward, x), [-1, 2])

        def _backward_head(_x: tf.Tensor) -> tf.Tensor:
            # The first window still contains backward pairs that the body misses.
            head_list = []
            first_row = _x[0, :]
            for i in range(w - 1):
                for d in range(1, i + 1):
                    head_list.append(tf.gather(first_row, [i, i - d]))
            return tf.stack(head_list)

        head = _backward_head(x)
        return tf.concat([head, body], axis=0)

    @classmethod
    def _get_skip_gram(cls, series: tf.Tensor, window_size: int) -> tf.Tensor:
        # Forward and backward pairs together form the full set of skip-gram pairs.
        return tf.concat(
            [
                cls._skip_forward(series, window_size),
                cls._skip_backward(series, window_size)
            ],
            axis=0
        )

    def to_skip_gram(self, window_size: int) -> tf.Tensor:
        return self._get_skip_gram(self.series, window_size)
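A minimal usage sketch with tf.data.Dataset, assuming TensorFlow 1.x graph mode; the `sentences` list, window size, and batch size below are hypothetical placeholders, not part of the gist.

# Usage sketch (assumption: TensorFlow 1.x, graph mode).
# `sentences`, the window size, and the batch size are made-up placeholders.
import tensorflow as tf

sentences = [[0, 1, 2, 3, 4], [5, 6, 7], [8, 9, 10, 11]]  # each at least window_size + 1 long

dataset = (
    tf.data.Dataset.from_generator(lambda: sentences, output_types=tf.int32)
    .map(lambda s: VarLenSeries(s).to_skip_gram(2))    # one [n_pairs, 2] tensor per sentence
    .flat_map(tf.data.Dataset.from_tensor_slices)      # stream individual (target, context) pairs
    .batch(4)
)

iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(batch))  # first batch of (target, context) pairs

The flat_map step is what makes the class convenient here: each variable-length sentence expands into a different number of pairs, and the resulting pairs are streamed and batched uniformly.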
It had a bug.