Skip to content

Instantly share code, notes, and snippets.

@OlavHN
Last active November 17, 2017 11:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save OlavHN/5c5607e5a3e1cb405001b139f4dbf90d to your computer and use it in GitHub Desktop.
Save OlavHN/5c5607e5a3e1cb405001b139f4dbf90d to your computer and use it in GitHub Desktop.
Example of hierarchical categories with IDs managed by TensorFlow.
import tensorflow as tf
def example(leafs_file='leafs.txt',
            parents_file='parents.txt',
            leaf2parents_file='leaf2parents.txt',
            leaf_names=('leaf1', 'leaf0', 'leaf8')):
    """Build a leaf -> parent category lookup graph and evaluate it once.

    Category ids come from TensorFlow lookup tables built from vocabulary
    files (per the tf.contrib.lookup docs, an entry's id is its 0-based line
    position). A random [batch x num_leafs] score matrix is then rolled up
    into a [batch x num_parents] matrix by summing, for every parent, the
    scores of the leaves that map to it.

    Args:
        leafs_file: vocabulary file with one leaf name per line.
        parents_file: vocabulary file with one parent name per line.
        leaf2parents_file: file whose line i holds the parent name of the
            leaf with id i.
        leaf_names: leaf names fed to the graph as one batch.

    Returns:
        The dict of evaluated tensors (leaf/parent ids and names, leaf and
        parent predictions). The same dict is also printed.
    """
    # name -> id tables for both levels, plus the leaf id -> parent name
    # table that encodes the hierarchy itself.
    leafs = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=leafs_file)
    parents = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=parents_file)
    leafs2parents = tf.contrib.lookup.index_to_string_table_from_file(
        vocabulary_file=leaf2parents_file)

    # Resolve the fed leaf names to ids, then to parent names and parent ids.
    leaf_category = tf.placeholder(tf.string, shape=[None])
    leaf_ids = leafs.lookup(leaf_category)
    parent_names = leafs2parents.lookup(leaf_ids)
    parent_ids = parents.lookup(parent_names)

    # Table sizes are int64; shapes and one_hot depth are given as int32 here.
    num_leafs = tf.to_int32(leafs.size())
    num_parents = tf.to_int32(parents.size())

    # "predictions" is a matrix of size [batch x num_leafs]
    predictions = tf.random_normal(
        shape=[tf.size(leaf_category), num_leafs])

    # Parent id for every leaf id 0 .. num_leafs-1.
    mapping = parents.lookup(leafs2parents.lookup(tf.range(leafs.size())))

    # one_hot(mapping) is [num_leafs x num_parents]. Right-multiplying by it
    # sums each parent's leaf scores, giving a [batch x num_parents] matrix.
    # NOTE(review): a leaf whose parent name is missing from parents_file
    # presumably looks up to -1 (the table's default), and one_hot(-1) is an
    # all-zero row, so such a leaf would silently drop out of the roll-up.
    # Confirm this if the vocabulary files can disagree.
    parent_predictions = tf.matmul(predictions,
                                   tf.one_hot(mapping, num_parents))

    with tf.Session() as sess:
        # Lookup tables must be initialized before any lookup runs.
        sess.run(tf.tables_initializer())
        res = sess.run(
            {
                "leaf_id": leaf_ids,
                "leaf_name": leaf_category,
                "parent_id": parent_ids,
                "parent_name": parent_names,
                "leaf_predictions": predictions,
                "parent_predictions": parent_predictions
            },
            feed_dict={
                leaf_category: list(leaf_names)
            })
    print(res)
    return res
def create_files(num_parents=5, leafs_per_parent=2,
                 leafs_file='leafs.txt',
                 parents_file='parents.txt',
                 leaf2parents_file='leaf2parents.txt'):
    """Write the three toy vocabulary files that ``example`` reads.

    With the defaults there are 5 parent categories with two leaf
    (sub)categories each: leaves ``leaf0``..``leaf9`` and parents
    ``parent0``..``parent4``. Leaf ``i`` is assigned to parent
    ``i % num_parents``. The leaves of one parent are therefore interleaved
    rather than contiguous; for example, leaf0 and leaf5 both map to parent0.

    Args:
        num_parents: number of parent categories.
        leafs_per_parent: number of leaf categories under each parent.
        leafs_file: output path; line i holds ``leaf{i}``.
        parents_file: output path; line i holds ``parent{i}``.
        leaf2parents_file: output path; line i holds the parent name of
            leaf i, i.e. ``parent{i % num_parents}``.
    """
    num_leafs = num_parents * leafs_per_parent

    # Mapping between ids and leafs
    with tf.gfile.Open(leafs_file, 'w') as f:
        f.write(''.join('leaf{}\n'.format(i) for i in range(num_leafs)))

    # Mapping between ids and parents
    with tf.gfile.Open(parents_file, 'w') as f:
        f.write(''.join('parent{}\n'.format(i) for i in range(num_parents)))

    # Mapping between leaf ids and parent names
    with tf.gfile.Open(leaf2parents_file, 'w') as f:
        f.write(''.join('parent{}\n'.format(i % num_parents)
                        for i in range(num_leafs)))
if __name__ == "__main__":
    # example() reads the vocabulary files back from disk, so they have to be
    # written first.
    create_files()
    example()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment