This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.distributed as dist | |
import torch.multiprocessing as mp | |
from torch.utils.data import IterableDataset, DataLoader | |
class DistributedIterableDataset(IterableDataset): | |
""" | |
Example implementation of an IterableDataset that handles both multiprocessing (num_workers > 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
import contextlib | |
import tensorflow as tf | |
from tensorflow.python.framework import ops | |
from tensorflow.contrib.distribute.python import collective_all_reduce_strategy | |
from tensorflow.python.distribute import multi_worker_test_base | |
from tensorflow.python.training import coordinator | |
from tensorflow.python.training import server_lib | |
from tensorflow.python.eager import context |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
"""Save a tensorflow model to a pb file.""" | |
# Build the model, then train it or load weights from somewhere else. | |
# ... | |
graph = tf.get_default_graph() | |
input_graph_def = graph.as_graph_def() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
freeze tensorflow checkpoint file to pb format. | |
""" | |
import argparse | |
import os | |
import tensorflow as tf | |
import logging | |
from tensorflow.python.framework import graph_util |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
sudo su | |
# Java | |
yum -y install java-1.8.0-openjdk-devel | |
# Build Esentials (minimal) | |
yum -y install gcc gcc-c++ kernel-devel make automake autoconf swig git unzip libtool binutils | |
# Extra Packages for Enterprise Linux (EPEL) (for pip, zeromq3) | |
yum -y install epel-release |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/bash | |
# Download zeromq | |
# Ref http://zeromq.org/intro:get-the-software | |
wget https://github.com/zeromq/libzmq/releases/download/v4.2.2/zeromq-4.2.2.tar.gz | |
# Unpack tarball package | |
tar xvzf zeromq-4.2.2.tar.gz | |
# Install dependency |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Tensorflow estimator API example | |
References: | |
- <https://www.tensorflow.org/guide/custom_estimators> | |
- <https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py> | |
- <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md> | |
""" | |
import numpy as np | |
import tensorflow as tf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Reference: https://stackoverflow.com/questions/43547402/how-to-calculate-f1-macro-in-keras""" | |
from keras import backend as K | |
def f1(y_true, y_pred): | |
def recall(y_true, y_pred): | |
"""Recall metric. | |
Only computes a batch-wise average of recall. | |
Computes the recall, a metric for multi-label classification of |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TopQueue: | |
"""keep top n num in a list with ascend order.""" | |
def __init__(self, n): | |
self.n = n | |
self.queue = [] | |
def add(self, new_val): | |
bigger_index = -1 | |
for i, one in enumerate(self.queue): | |
if one > new_val: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
def masked_softmax(logits, mask, axis): | |
"""softmax with mask after the exp operation""" | |
e_logits = tf.exp(logits) | |
masked_e = tf.multiply(e_logits, mask) | |
sum_masked_e = tf.reduce_sum(masked_e, axis, keep_dims=True) | |
# be careful when the sum is 0 | |
ones = tf.ones_like(sum_masked_e) | |
sum_masked_e_safe = tf.where(tf.equal(sum_masked_e, 0), ones, sum_masked_e) |
NewerOlder