Skip to content

Instantly share code, notes, and snippets.

View piyushghai's full-sized avatar
🎯
Focusing

Piyush Ghai piyushghai

🎯
Focusing
  • San Francisco Bay Area, California
View GitHub Profile
#
# ResNet-50 model training using Keras and Horovod.
#
import argparse
from tensorflow import keras
from tensorflow.python.keras import backend as K
from tensorflow.keras.preprocessing import image
import tensorflow as tf
import horovod.tensorflow.keras as hvd
@piyushghai
piyushghai / validate_tf_records.py
Last active October 7, 2019 21:32
validate_tf_records
import tensorflow as tf
import argparse
import os
from tqdm import tqdm
def validate_record(file_path):
print ('Validating record {}'.format(file_path))
count = 0
for example in tf.python_io.tf_record_iterator(file_path):
ex = tf.train.Example.FromString(example)
@piyushghai
piyushghai / hdf5_to_tfrecord_converter.py
Created September 17, 2019 21:06
hdf5_to_tfrecord_converter
import argparse
import collections
import os
import deepdish as dd
import numpy as np
import tensorflow as tf
from tqdm import tqdm
parser = argparse.ArgumentParser()
package sample
import java.io.File
import scala.io.Source
import javax.imageio.ImageIO
import org.apache.mxnet._
import org.apache.mxnet.infer.{ImageClassifier, Predictor}
object PredictorSample {
package mxnet;
import org.apache.mxnet.infer.javaapi.Predictor;
import org.apache.mxnet.javaapi.*;
import org.apache.mxnet.javaapi.Shape;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import mxnet as mx
from mxnet import ndarray as nd
import numpy as np
from collections import namedtuple
import math
Batch = namedtuple('Batch', ['data'])
ctx = mx.gpu()
use_batch=True
num_runs=1000
@piyushghai
piyushghai / ImageClassifier.java
Last active November 13, 2018 01:43
ImageClassification in Java using Predictor API
package mxnet;
import org.apache.mxnet.infer.javaapi.Predictor;
import org.apache.mxnet.javaapi.Context;
import org.apache.mxnet.javaapi.DType;
import org.apache.mxnet.javaapi.DataDesc;
import org.apache.mxnet.javaapi.Shape;
import javax.imageio.ImageIO;
import java.awt.*;
@piyushghai
piyushghai / pre_process_input.java
Created November 12, 2018 23:53
pre_process_input
/**
 * Reads the image stored at the given filesystem path.
 *
 * @param inputImagePath path of the image file to read
 * @return the result of {@code ImageIO.read} for that file
 * @throws IOException if reading the file fails
 */
// NOTE(review): ImageIO.read is documented to return null when no registered
// reader can decode the file -- confirm that callers handle a null result.
private static BufferedImage loadIamgeFromFile(String inputImagePath) throws IOException {
    File imageFile = new File(inputImagePath);
    return ImageIO.read(imageFile);
}
private static BufferedImage reshapeImage(BufferedImage buf, int newWidth, int newHeight) {
BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
Graphics2D g = resizedImage.createGraphics();
g.drawImage(buf, 0, 0, newWidth, newHeight, null);
g.dispose();
@piyushghai
piyushghai / model-symbol.json
Created October 10, 2018 22:59
LSTM Unroll Name conflict issue
{
"nodes": [
{
"op": "null",
"name": "data",
"inputs": []
},
{
"op": "SliceChannel",
"name": "myblock0_split0",
@piyushghai
piyushghai / remitly_bot.py
Last active October 10, 2018 06:39
Bot to fetch latest USD rates from Remitly Inc.
#!/usr/bin/env python
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
# http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.