seqs = [create_seq(i) for i in movie_plots]
# merge list-of-lists into a single list
seqs = sum(seqs, [])
# count of sequences
len(seqs)
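`sum(seqs, [])` gives the right result, but it builds a new list on every addition, so the cost grows quadratically with the number of sub-lists. A minimal sketch of a single-pass alternative, assuming `seqs` is a list of lists of strings; the `nested` sample data is hypothetical:

```python
from itertools import chain

# hypothetical stand-in for the per-plot lists of sequences
nested = [["a b c", "b c d"], ["e f g"], []]

# flatten in one pass instead of repeated list concatenation
flat = list(chain.from_iterable(nested))

# count of sequences after flattening
print(len(flat))
```

For a few thousand short lists the difference is small; it becomes noticeable once the corpus produces hundreds of thousands of sequences.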
# create sequences of seq_len + 1 tokens (6 with the default seq_len = 5)
def create_seq(text, seq_len = 5):
    sequences = []
    # if the number of tokens in 'text' is greater than seq_len
    if len(text.split()) > seq_len:
        for i in range(seq_len, len(text.split())):
            # select a window of seq_len + 1 tokens ending at position i
            seq = text.split()[i-seq_len:i+1]
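The snippet above is cut off right after the slice. A minimal sketch of how such a function could be completed, assuming each window is joined back into a space-separated string and that a text with `seq_len` or fewer tokens is returned unchanged as a single sequence. Both choices are assumptions for illustration and are not confirmed by the source:

```python
def create_seq(text, seq_len=5):
    """Split `text` into overlapping windows of seq_len + 1 tokens."""
    tokens = text.split()  # tokenize once rather than on every iteration

    # assumption: short texts are returned whole as a single sequence
    if len(tokens) <= seq_len:
        return [text]

    sequences = []
    for i in range(seq_len, len(tokens)):
        # window of seq_len + 1 tokens ending at position i
        window = tokens[i - seq_len:i + 1]
        # assumption: windows are stored as space-joined strings
        sequences.append(" ".join(window))
    return sequences


example = create_seq("the quick brown fox jumps over the lazy dog")
print(example)
```

A text of n tokens (n > `seq_len`) yields n - `seq_len` windows, each shifted by one token from the previous one.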
# read pickle file (the context manager closes the file handle afterwards)
with open("plots_text.pickle", "rb") as pickle_in:
    movie_plots = pickle.load(pickle_in)
# count of movie plot summaries
len(movie_plots)
import re
import pickle
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
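# classification_report is provided by scikit-learn
from sklearn.metrics import classification_report
# pick the highest-scoring class for each test example
preds = np.argmax(preds, axis = 1)
print(classification_report(test_y, preds))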
# get predictions for test data
with torch.no_grad():
    preds = model(test_seq.to(device), test_mask.to(device))
    preds = preds.detach().cpu().numpy()
# load weights of best model
path = 'saved_weights.pt'
model.load_state_dict(torch.load(path))
# set initial loss to infinity
best_valid_loss = float('inf')
# empty lists to store training and validation loss of each epoch
train_losses = []
valid_losses = []
# for each epoch
for epoch in range(epochs):
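The loop body is not shown in the source. A minimal sketch of the pattern this setup points to: run one training pass and one evaluation pass per epoch, record both losses, and save a checkpoint whenever the validation loss improves. The `train_fn`, `evaluate_fn`, and `save_fn` callables are hypothetical stand-ins (in the real script they would presumably wrap `train()`, `evaluate()`, and a call such as `torch.save(model.state_dict(), 'saved_weights.pt')`), and the loss values are made up for illustration:

```python
def run_epochs(epochs, train_fn, evaluate_fn, save_fn):
    """Track per-epoch losses and checkpoint on validation improvement."""
    best_valid_loss = float("inf")
    train_losses, valid_losses = [], []

    for epoch in range(epochs):
        train_loss = train_fn(epoch)     # hypothetical: one training pass
        valid_loss = evaluate_fn(epoch)  # hypothetical: one validation pass

        # keep only the weights with the lowest validation loss so far
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            save_fn(epoch)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

    return best_valid_loss, train_losses, valid_losses


# made-up losses: validation improves for two epochs, then degrades
fake_train = [0.9, 0.7, 0.5, 0.4]
fake_valid = [0.8, 0.6, 0.65, 0.7]
saved_at = []  # epochs at which a checkpoint would have been written

best, train_losses, valid_losses = run_epochs(
    4,
    lambda e: fake_train[e],
    lambda e: fake_valid[e],
    saved_at.append,
)
print(best, saved_at)
```

Checkpointing on validation loss rather than training loss means the weights loaded later come from the epoch that generalized best, not the one that fit the training data most closely.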
# function for evaluating the model
def evaluate():
    print("\nEvaluating...")
    # deactivate dropout layers
    model.eval()
    total_loss, total_accuracy = 0, 0
# function to train the model
def train():
    model.train()
    total_loss, total_accuracy = 0, 0
    # empty list to save model predictions
    total_preds = []