# Add a new column named 'Prediction' holding the category name for each numeric label
df['Prediction'] = df['labels'].map(label_map)
print(df)
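`Series.map` with a dict replaces each numeric label with its category name, and any id missing from the dict becomes `NaN`. A minimal sketch on a toy frame, assuming the same five-class `label_map` used in this gist (the rows are hypothetical):

```python
import pandas as pd

# Assumed to match the label_map defined in this gist (ids 0-4).
label_map = {0: "Business", 1: "Entertainment", 2: "Politics", 3: "Sport", 4: "Tech"}

# Toy stand-in for the real prediction DataFrame (hypothetical rows).
df = pd.DataFrame({"data": ["rates rise", "new phone", "cup final"], "labels": [0, 4, 3]})

# Add a 'Prediction' column holding the human-readable category.
df["Prediction"] = df["labels"].map(label_map)

# Ids not present in label_map map to NaN; selecting them flags bad predictions.
unknown = df[df["Prediction"].isna()]
print(df)
```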
label_map = {
    0: "Business",
    1: "Entertainment",
    2: "Politics",
    3: "Sport",
    4: "Tech"
}
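The same dictionary can be inverted to turn category names back into ids, or read in key order to build the list of class names that the confusion-matrix plot spells out by hand. A small sketch, assuming the ids are contiguous from 0:

```python
# Numeric class id -> category name (same contents as label_map in this gist).
label_map = {0: "Business", 1: "Entertainment", 2: "Politics", 3: "Sport", 4: "Tech"}

# Inverse mapping: category name -> numeric class id (e.g. for encoding training labels).
name_to_id = {name: idx for idx, name in label_map.items()}

# Ordered class names; position i holds the name of class i (usable as tick labels).
class_names = [label_map[i] for i in sorted(label_map)]

print(name_to_id["Sport"])
print(class_names)
```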
import pandas as pd

# Predict a category id for every unlabelled text
unlabelled_predictions = []
for data in unlabelled_data:
    unlabelled_predictions.append(predict_category(data))

# Save each text alongside its predicted label
prediction_df = pd.DataFrame({
    "data": unlabelled_data,
    "labels": unlabelled_predictions,
})
prediction_df.to_csv("model_prediction.csv", index=False)
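The prediction loop can be written as a list comprehension with identical behaviour. The sketch below swaps in a hypothetical keyword-based `predict_category` (the real one calls the fine-tuned model) and writes to an in-memory buffer instead of `model_prediction.csv`, so it runs without TensorFlow or side effects:

```python
import io

import pandas as pd


def predict_category(text):
    """Hypothetical stand-in for the model-backed predict_category: returns a class id."""
    return 3 if "match" in text.lower() else 0


# Hypothetical unlabelled inputs.
unlabelled_data = ["Shares fell sharply", "The match ended in a draw"]

# Equivalent to the append loop: one predicted id per text, in order.
unlabelled_predictions = [predict_category(text) for text in unlabelled_data]

prediction_df = pd.DataFrame({"data": unlabelled_data, "labels": unlabelled_predictions})

# Round-trip through CSV in memory to check what would land on disk.
buffer = io.StringIO()
prediction_df.to_csv(buffer, index=False)
buffer.seek(0)
round_trip = pd.read_csv(buffer)
print(round_trip)
```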
import pandas as pd

# '?raw=true' makes GitHub serve the raw CSV instead of the HTML blob page
url_unknown = 'https://github.com/kiddojazz/Multitext-Classification/blob/master/bbc_data.csv?raw=true'
df_unknown = pd.read_csv(url_unknown)
print(df_unknown.head(5))
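`pd.read_csv` accepts a URL or any file-like object, so the same call works offline on an in-memory buffer. This sketch assumes the file has a text column named `data` (the name the prediction DataFrame in this gist uses) and turns it into the `unlabelled_data` list that the prediction loop iterates over; the CSV contents are made up:

```python
import io

import pandas as pd

# In-memory stand-in for the downloaded CSV; the 'data' column name is an assumption.
csv_text = "data,labels\nOil prices climb,business\nStriker signs new deal,sport\n"

df_unknown = pd.read_csv(io.StringIO(csv_text))
print(df_unknown.head(5))

# Extract the raw texts as a plain Python list to feed predict_category one by one.
unlabelled_data = df_unknown["data"].astype(str).tolist()
```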
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import seaborn as sns

class_names = ["Business", "Entertainment", "Politics", "Sport", "Tech"]

# Rows are true labels, columns are predicted labels
confusion = confusion_matrix(test_labels, y_pred)
plt.figure(figsize=(8, 6))
sns.set(font_scale=1.2)
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", cbar=False, square=True,
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()

print(classification_report(test_labels, y_pred, target_names=class_names))
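`confusion_matrix` places true classes on rows and predicted classes on columns, which is why the x-axis is labelled "Predicted". The per-class precision and recall that `classification_report` prints follow directly from that matrix. A NumPy sketch of the same bookkeeping on made-up labels (not results from this model):

```python
import numpy as np


def confusion_counts(y_true, y_pred, n_classes):
    """Rows are true classes, columns are predicted classes (same layout as sklearn)."""
    matrix = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        matrix[t, p] += 1
    return matrix


# Hypothetical labels for five classes (0=Business ... 4=Tech).
y_true = [0, 0, 1, 2, 3, 3, 4, 4]
y_pred = [0, 2, 1, 2, 3, 3, 4, 0]

cm = confusion_counts(y_true, y_pred, n_classes=5)

# Diagonal = correct predictions; column sums = predicted counts; row sums = true counts.
tp = np.diag(cm)
precision = tp / np.maximum(cm.sum(axis=0), 1)  # max(..., 1) avoids division by zero
recall = tp / np.maximum(cm.sum(axis=1), 1)
accuracy = tp.sum() / cm.sum()
print(cm)
```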
import tensorflow as tf

def predict_category(text):
    """Return the predicted class id (0-4) for a single text."""
    predict_input = loaded_tokenizer.encode(text,
                                            truncation=True,
                                            padding=True,
                                            return_tensors="tf")
    output = loaded_model(predict_input)[0]  # logits, shape (1, num_labels)
    prediction_value = tf.argmax(output, axis=1).numpy()[0]
    return prediction_value
# - - - - - - - - - - - - - - - - - - - - - - - - - - -
y_pred = []
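`tf.argmax(output, axis=1)` returns the index of the largest logit in each row. The NumPy equivalent below, on made-up logits, shows why `[0]` is needed for a single text and how a batch of texts would yield one class id per row:

```python
import numpy as np

# Hypothetical logits for one input, shape (batch=1, num_classes=5),
# standing in for what loaded_model(predict_input)[0] returns.
logits = np.array([[0.2, -1.3, 0.4, 2.7, 0.1]])

# argmax over axis=1 picks the highest-scoring class per row; [0] takes the only row.
prediction_value = int(np.argmax(logits, axis=1)[0])
print(prediction_value)  # 3

# With a batch, the same call returns one id per input text.
batch_logits = np.array([[2.0, 0.1, 0.0, 0.3, 0.2],
                         [0.0, 0.1, 0.0, 0.3, 1.9]])
batch_predictions = np.argmax(batch_logits, axis=1).tolist()
print(batch_predictions)  # [0, 4]
```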
predict_input = loaded_tokenizer.encode(test_text,
                                        truncation=True,
                                        padding=True,
                                        return_tensors="tf")
output = loaded_model(predict_input)[0]
prediction_value = tf.argmax(output, axis=1).numpy()[0]
# Convert numeric prediction to category label
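The argmax id converts to a name with a plain `label_map` lookup. If a confidence score is also wanted, a softmax over the logits provides one; that is an optional addition, not part of the original gist. A minimal NumPy sketch on hypothetical logits:

```python
import numpy as np

label_map = {0: "Business", 1: "Entertainment", 2: "Politics", 3: "Sport", 4: "Tech"}


def to_label(logits_row):
    """Return (category name, softmax probability) for one row of logits."""
    shifted = logits_row - np.max(logits_row)  # subtract the max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()
    idx = int(np.argmax(probs))
    return label_map[idx], float(probs[idx])


# Equal logits give equal probabilities; argmax falls back to the first class.
label, confidence = to_label(np.zeros(5))

# A logit of ln(6) against four zeros gives that class probability 6 / (6 + 4) = 0.6.
sport_label, sport_conf = to_label(np.array([0.0, 0.0, 0.0, np.log(6.0), 0.0]))
print(sport_label, round(sport_conf, 3))
```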
# Business = 0, Entertainment = 1, Politics = 2, Sport = 3, Tech = 4
predict_input = loaded_tokenizer.encode(test_text,
                                        truncation=True,
                                        padding=True,
                                        return_tensors="tf")
output = loaded_model(predict_input)[0]
prediction_value = tf.argmax(output, axis=1).numpy()[0]
prediction_value
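from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification

# Reload the fine-tuned tokenizer and model written by save_pretrained()
save_directory = "Multitext_Classification_2"
loaded_tokenizer = DistilBertTokenizer.from_pretrained(save_directory)
loaded_model = TFDistilBertForSequenceClassification.from_pretrained(save_directory)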
save_directory = "Multitext_Classification_colab"  # Change this to your preferred location
# Write model weights/config and tokenizer files so from_pretrained() can reload them
model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)