Skip to content

Instantly share code, notes, and snippets.

Created December 9, 2019 21:18
Show Gist options
  • Save nbortolotti/e0d2d014a12edf50af3272318f9b7130 to your computer and use it in GitHub Desktop.
import pandas as pd
import numpy as np
import tensorflow as tf
# Locations of the training/test CSV files. Both are empty in this copy;
# presumably they pointed at the TensorFlow iris_training.csv / iris_test.csv
# files — TODO confirm and fill in before running.
train_ds_url = ""
test_ds_url = ""
# Four feature columns followed by the label column ('Plants').
ds_columns = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Plants']
# Human-readable class names, ordered to match the model's output classes.
# Plain `object` replaces the `np.object` alias, which NumPy 1.24 removed.
species = np.array(['Setosa', 'Versicolor', 'Virginica'], dtype=object)
# Load data
categories = 'Plants'  # name of the label column, split off from each dataset


def _load_split(url):
    """Download the CSV at *url* and return it as a (features, labels) pair.

    The file's own header row is discarded (header=0) and replaced by
    ds_columns. The label column is popped off, so the returned features
    DataFrame contains only the feature columns.

    Raises:
        ValueError: if *url* is empty.
    """
    if not url:
        # get_file would otherwise fail with an unhelpful error on an empty name.
        raise ValueError(
            "dataset URL is empty; set train_ds_url and test_ds_url before running"
        )
    path = tf.keras.utils.get_file(url.split('/')[-1], url)
    frame = pd.read_csv(path, names=ds_columns, header=0)
    labels = frame.pop(categories)  # removes the label column from frame in place
    return frame, labels


train_plantfeatures, train_categories = _load_split(train_ds_url)
test_plantfeatures, test_categories = _load_split(test_ds_url)
# One-hot encode the labels. This assumes the CSV stores labels as integers in
# [0, len(species)), as to_categorical requires — TODO confirm against the data.
y_categorical = tf.keras.utils.to_categorical(train_categories, num_classes=len(species))
y_categorical_test = tf.keras.utils.to_categorical(test_categories, num_classes=len(species))
# Build the tf.data input pipelines.
# Training: shuffle individual rows *before* batching, so that every pass
# yields freshly mixed batches (calling shuffle() after batch() only reorders
# whole, fixed batches). Then repeat indefinitely; the consumer must pass an
# explicit step count.
dataset = (
    tf.data.Dataset.from_tensor_slices((train_plantfeatures.to_numpy(), y_categorical))
    .shuffle(1000)
    .batch(32)
    .repeat()
)
# Evaluation: shuffling serves no purpose here, so it is omitted. repeat() is
# kept because evaluation of this pipeline is bounded by an explicit steps count.
dataset_test = (
    tf.data.Dataset.from_tensor_slices((test_plantfeatures.to_numpy(), y_categorical_test))
    .batch(32)
    .repeat()
)
# Build the model: 4 input features -> 16-unit hidden layer -> 3-way softmax.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    # ReLU gives the hidden layer a nonlinearity. Without one, two stacked
    # Dense layers reduce to a single affine map and the hidden layer adds
    # nothing.
    tf.keras.layers.Dense(16, activation=tf.nn.relu),
    tf.keras.layers.Dense(3, activation=tf.nn.softmax),
])
# Labels are one-hot encoded (to_categorical), hence categorical rather than
# sparse cross-entropy. A single 'accuracy' metric makes evaluate() return
# exactly [loss, accuracy].
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Train model. The training dataset repeats indefinitely, so steps_per_epoch
# is what bounds each epoch.
model.fit(dataset, steps_per_epoch=32, epochs=100, verbose=1)
# Evaluate on the test pipeline. dataset_test repeats, so the explicit step
# count is what bounds this evaluation run.
loss, accuracy = model.evaluate(dataset_test, steps=32)
print("loss:{:f}".format(loss))
print("accuracy: {:f}".format(accuracy))
# predict
# One sample: four feature values, in the same order as the feature columns.
new_specie = np.array([7.9, 3.8, 6.4, 2.0])
# Predict once and reuse the result (the model expects a batch, hence expand_dims).
probabilities = model.predict(np.expand_dims(new_specie, axis=0))
print(probabilities)
# argmax always selects exactly one class. The previous approach (round the
# softmax output, then use it as a boolean mask via the since-removed np.bool)
# raised IndexError whenever no class probability exceeded 0.5.
predicted_index = int(np.argmax(probabilities[0]))
print("This species should be %s" % species[predicted_index])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment