Created
September 25, 2018 05:10
-
-
Save susanli2016/b905f0be2f4b6f501f0d9b1d33835118 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import os | |
%matplotlib inline | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import tensorflow as tf | |
from sklearn.preprocessing import LabelBinarizer, LabelEncoder | |
from sklearn.metrics import confusion_matrix | |
from tensorflow import keras | |
from keras.models import Sequential | |
from keras.layers import Dense, Activation, Dropout | |
from keras.preprocessing import text, sequence | |
from keras import utils | |
# ---------------------------------------------------------------------------
# Bag-of-words tag classifier.
#
# Steps: positional train/test split -> tokenize posts into fixed-width
# vectors -> integer-encode and one-hot the tags -> build, compile and train a
# single-hidden-layer dense network.
#
# Requires `df` to be defined before this point, with 'post' and 'tags'
# columns (presumably a pandas DataFrame -- confirm against the loading code).
# ---------------------------------------------------------------------------

# Positional split: the first 70% of rows train, the remaining 30% test.
# No shuffling happens here, so any ordering in `df` carries into the split.
train_size = int(len(df) * .7)
train_posts = df['post'][:train_size]
train_tags = df['tags'][:train_size]
test_posts = df['post'][train_size:]
test_tags = df['tags'][train_size:]

# Vocabulary size cap; also the width of the model's input layer below.
max_words = 1000
tokenize = text.Tokenizer(num_words=max_words, char_level=False)
tokenize.fit_on_texts(train_posts)  # only fit on train

# Turn each post into one fixed-length row (max_words columns).
x_train = tokenize.texts_to_matrix(train_posts)
x_test = tokenize.texts_to_matrix(test_posts)

# Map tag values to integer ids. The encoder is fit on the training tags only,
# then applied to both splits.
encoder = LabelEncoder()
encoder.fit(train_tags)
y_train = encoder.transform(train_tags)
y_test = encoder.transform(test_tags)

# Ids are 0-based, so the largest training id + 1 is the number of classes.
num_classes = np.max(y_train) + 1
y_train = utils.to_categorical(y_train, num_classes)
y_test = utils.to_categorical(y_test, num_classes)

batch_size = 32
epochs = 2

# Build the model
model = Sequential()
model.add(Dense(512, input_shape=(max_words,)))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# Train, setting aside 10% of the training data for validation.
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_split=0.1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment