NOTE: scraped GitHub gist. The "bidirectional Unicode" lines repeated between chunks below are embed-widget boilerplate from the scrape, not part of the program.
## Required packages
import random

import pandas as pd
import seaborn as sns
import spacy
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from spacy.util import minibatch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
######## Main method ########
def main():
    """Load the dataset and print basic class-balance statistics.

    Reads a CSV from the module-level ``data_path`` (defined in an
    earlier chunk of the original script -- TODO confirm) and prints
    the row count plus the absolute and percentage frequency of each
    value in the 'label' column.
    """
    # Load dataset
    data = pd.read_csv(data_path)
    observations = len(data.index)
    print("Dataset Size: {}".format(observations))
    # Absolute counts per class, then the same as percentages
    print(data['label'].value_counts())
    print(data['label'].value_counts() / len(data.index) * 100.0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
######## Main method ########
def main():
    """Load the dataset and report its size.

    Second revision of ``main()`` in the scraped gist; this chunk is
    truncated right before the spaCy-model construction step. Reads a
    CSV from the module-level ``data_path`` (defined in an earlier
    chunk -- TODO confirm).
    """
    # Load dataset
    data = pd.read_csv(data_path)
    observations = len(data.index)
    print("Dataset Size: {}".format(observations))
    # Create an empty spacy model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Split data into train and test datasets (67/33); random_state fixes the
# shuffle so the split is reproducible.
# NOTE(review): `data` is the DataFrame loaded in an earlier chunk of the
# scraped script; `train_test_split` needs an import from
# sklearn.model_selection (missing in the scraped import block).
x_train, x_test, y_train, y_test = train_test_split(
    data['text'], data['label'], test_size=0.33, random_state=7)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create the train and test annotations in spaCy's textcat format: one dict
# per example, {'cats': {category: bool}}, with mutually exclusive
# 'ham'/'spam' flags. (Typo 'lables' -> 'labels' fixed; both names are
# internal to this fragment.)
train_labels = [{'cats': {'ham': label == 'ham',
                          'spam': label == 'spam'}} for label in y_train]
test_labels = [{'cats': {'ham': label == 'ham',
                         'spam': label == 'spam'}} for label in y_test]
# Spacy model data: lists of (text, annotations) pairs
train_data = list(zip(x_train, train_labels))
test_data = list(zip(x_test, test_labels))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_model(model, train_data, optimizer, batch_size, epochs=10):
    """Train the model's textcat pipe on (text, annotations) pairs.

    Shuffles the training data every epoch (seeded for reproducibility)
    and updates the model in minibatches; ``losses`` accumulates the
    running loss per pipe across all updates.
    """
    losses = {}
    random.seed(1)
    for epoch in range(epochs):
        random.shuffle(train_data)
        batches = minibatch(train_data, size=batch_size)
        for batch in batches:
            # NOTE(review): the loop body was truncated in the scraped
            # source; reconstructed per the standard spaCy v2 textcat
            # training pattern -- confirm against the original gist.
            texts, labels = zip(*batch)
            model.update(texts, labels, sgd=optimizer, losses=losses)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Model configurations
# NOTE(review): `nlp` is the spaCy pipeline built in an earlier chunk of
# the scraped script; begin_training() initializes weights and returns
# the optimizer (spaCy v2 API).
optimizer = nlp.begin_training()
batch_size = 5
epochs = 10
# Training the model
train_model(nlp, train_data, optimizer, batch_size, epochs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sample predictions: run the trained pipeline on the first training text
# (train_data items are (text, annotations) pairs, so [0][0] is the text)
# and print the per-category scores from doc.cats.
print(train_data[0])
sample_test = nlp(train_data[0][0])
print(sample_test.cats)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_predictions(model, texts):
    """Return the predicted category name for each text in ``texts``."""
    # Use the model's tokenizer to tokenize each input text
    docs = [model.tokenizer(text) for text in texts]
    # Use textcat to get the scores for each doc
    textcat = model.get_pipe('textcat')
    scores, _ = textcat.predict(docs)
    # From the scores, find the label with the highest score/probability.
    # NOTE(review): the tail of this function was truncated in the scraped
    # source; reconstructed per the standard spaCy v2 pattern -- confirm
    # against the original gist.
    predicted_labels = scores.argmax(axis=1)
    return [textcat.labels[label] for label in predicted_labels]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train and test accuracy of the trained classifier.
# NOTE(review): `nlp`, `x_train`, `x_test`, `y_train`, `y_test` come from
# earlier chunks of the scraped script.
train_predictions = get_predictions(nlp, x_train)
test_predictions = get_predictions(nlp, x_test)
train_accuracy = accuracy_score(y_train, train_predictions)
test_accuracy = accuracy_score(y_test, test_predictions)
print("Train accuracy: {}".format(train_accuracy))
print("Test accuracy: {}".format(test_accuracy))
(End of scraped gist; "OlderNewer" was gist revision-navigation residue.)