This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
def masked_softmax(logits, mask, axis): | |
"""softmax with mask after the exp operation""" | |
e_logits = tf.exp(logits) | |
masked_e = tf.multiply(e_logits, mask) | |
sum_masked_e = tf.reduce_sum(masked_e, axis, keep_dims=True) | |
# be careful when the sum is 0 | |
ones = tf.ones_like(sum_masked_e) | |
sum_masked_e_safe = tf.where(tf.equal(sum_masked_e, 0), ones, sum_masked_e) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TopQueue: | |
"""keep top n num in a list with ascend order.""" | |
def __init__(self, n): | |
self.n = n | |
self.queue = [] | |
def add(self, new_val): | |
bigger_index = -1 | |
for i, one in enumerate(self.queue): | |
if one > new_val: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Reference: https://stackoverflow.com/questions/43547402/how-to-calculate-f1-macro-in-keras""" | |
from keras import backend as K | |
def f1(y_true, y_pred): | |
def recall(y_true, y_pred): | |
"""Recall metric. | |
Only computes a batch-wise average of recall. | |
Computes the recall, a metric for multi-label classification of |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/bash | |
# Download zeromq | |
# Ref http://zeromq.org/intro:get-the-software | |
wget https://github.com/zeromq/libzmq/releases/download/v4.2.2/zeromq-4.2.2.tar.gz | |
# Unpack tarball package | |
tar xvzf zeromq-4.2.2.tar.gz | |
# Install dependency |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
sudo su | |
# Java | |
yum -y install java-1.8.0-openjdk-devel | |
# Build Essentials (minimal) | |
yum -y install gcc gcc-c++ kernel-devel make automake autoconf swig git unzip libtool binutils | |
# Extra Packages for Enterprise Linux (EPEL) (for pip, zeromq3) | |
yum -y install epel-release |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
freeze tensorflow checkpoint file to pb format. | |
""" | |
import argparse | |
import os | |
import tensorflow as tf | |
import logging | |
from tensorflow.python.framework import graph_util |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
"""Save a tensorflow model to a pb file.""" | |
# Build the model, then train it or load weights from somewhere else. | |
# ... | |
graph = tf.get_default_graph() | |
input_graph_def = graph.as_graph_def() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
import tensorflow as tf | |
def get_session():
    """Create and return a new TensorFlow session.

    The session is built from a ``tf.ConfigProto`` whose
    ``gpu_options.allow_growth`` flag is switched on, which asks TensorFlow
    to grow its GPU memory usage as needed.
    """
    session_config = tf.ConfigProto()
    # Flip the flag on the config's own gpu_options sub-message.
    gpu_settings = session_config.gpu_options
    gpu_settings.allow_growth = True
    return tf.Session(config=session_config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Refer to https://www.kaggle.com/jhoward/nb-svm-strong-linear-baseline | |
import re, string | |
# Characters that become stand-alone tokens: ASCII punctuation plus a handful
# of typographic symbols. re.escape keeps "\", "]", "^" and "-" literal inside
# the character class; left unescaped, the "\" from string.punctuation only
# escaped the following "]" and was therefore never matched itself.
re_tok = re.compile(f'([{re.escape(string.punctuation)}“”¨«»®´·º½¾¿¡§£₤‘’])')


def tokenize(s):
    """Split *s* into tokens, treating each punctuation mark as its own token.

    Every character matched by ``re_tok`` is surrounded with spaces and the
    result is split on whitespace, so an empty or whitespace-only string
    yields an empty list.

    Args:
        s: The text to tokenize.

    Returns:
        A list of token strings, in their original order.
    """
    return re_tok.sub(r' \1 ', s).split()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# variation to https://github.com/ryankiros/skip-thoughts/blob/master/decoding/search.py | |
def keras_rnn_predict(samples, empty=empty, rnn_model=model, maxlen=maxlen):
    """Return the RNN's predictions (per-label probabilities) for each sample.

    The samples are first passed through ``sequence.pad_sequences`` with
    ``maxlen`` as the target length and ``empty`` as the padding value; the
    padded batch is then handed to ``rnn_model.predict`` with ``verbose=0``.

    The defaults for ``empty``, ``rnn_model`` and ``maxlen`` are taken from
    module-level names when this function is defined, so supply your own
    model and the sequence length it can handle if those differ.
    """
    return rnn_model.predict(
        sequence.pad_sequences(samples, maxlen=maxlen, value=empty),
        verbose=0,
    )
def beamsearch(predict=keras_rnn_predict, |
OlderNewer