-
-
Save b00blik/f165fa1c01cdb621d537f6d031b08122 to your computer and use it in GitHub Desktop.
tf-kickstart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tflearn | |
# Функция препроцессинга | |
def preprocess(data, columns_to_ignore): | |
# Сортируем по id по убыванию и выпилиываем колонки | |
for id in sorted(columns_to_ignore, reverse=True): | |
[r.pop(id) for r in data] | |
for i in range(len(data)): | |
data[i][1] = 1. if data[i][1] == 'female' else 0. | |
return np.array(data, dtype=np.float32) | |
# Загрузим набор даных для нашего кейса | |
from tflearn.datasets import titanic | |
titanic.download_dataset('titanic_dataset.csv') | |
# Прочитаем CSV, после этого обозначим | |
# что верхние строки (заголовки таблицы) это метки | |
from tflearn.data_utils import load_csv | |
data, labels = load_csv('titanic_dataset.csv', target_column=0, | |
categorical_labels=True, n_classes=2) | |
# Игнорируем колонки 'name' и 'ticket' | |
to_ignore = [1,6] | |
# Препроцессим данные | |
data = preprocess(data, to_ignore) | |
# Строим нейросеть | |
net = tflearn.input_data(shape=[None, 6]) | |
net = tflearn.fully_connected(net, 32) | |
net = tflearn.fully_connected(net, 32) | |
net = tflearn.fully_connected(net, 2, activation='softmax') | |
net = tflearn.regression(net) | |
# Определим модель | |
model = tflearn.DNN(net) | |
# Обучим | |
model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True) | |
# Создадим кастомных данных | |
dicaprio = [3, 'Jack Dawson', 'male', 19, 0, 0, 'N/A', 5.0000] | |
winslet = [1, 'Rose DeWitt Bukater', 'female', 17, 1, 2, 'N/A', 100.0000] | |
# Запрепроцессим | |
dicaprio, winslet = preprocess([dicaprio, winslet], to_ignore) | |
# Предскажем шансы на выживание | |
pred = model.predict([dicaprio, winslet]) | |
print("DiCaprio Surviving Rate:", pred[0][1]) | |
print("Winslet Surviving Rate:", pred[1][1]) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment