# Train a classification task; see the full introduction.
#!pip install torch
#!pip install transformers
#!pip install scikit-learn
#!pip install numpy
import json
import random
import sys
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score, recall_score, accuracy_score, precision_score
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer, BertModel
# Redirect stdout/stderr to a line-buffered log file so that every print and
# traceback produced during training ends up on disk.
log_file = open('lawer_train_output.log', 'w', buffering=1, encoding='utf-8')
sys.stdout = log_file
sys.stderr = log_file

# Load the positive samples (labelled 1 below), one JSON object per line.
# NOTE(review): the loop bodies were missing in the original; parsing each
# line with json.loads is reconstructed from the later d['content'] access
# -- confirm against the data format.
lawer_data = []
with open('./train_lawer.json', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if line:  # tolerate blank lines / trailing newline
            lawer_data.append(json.loads(line))

# Load the negative samples (labelled 0 below), same format.
not_lawer_data = []
with open('./train_notlawer.json', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if line:
            not_lawer_data.append(json.loads(line))

# Reduce every record to a (text, label) pair.
lawer_data = [(str(d['content']), 1) for d in lawer_data]
not_lawer_data = [(str(d['content']), 0) for d in not_lawer_data]

# Merge both classes. No explicit shuffle is done here; train_test_split
# shuffles by default before splitting, seeded by random_state.
merged_data = lawer_data + not_lawer_data

# Separate features and labels.
X, y = zip(*merged_data)

# Split the data: 80% for training, 20% for validation.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')
def tokenize(content, max_length=512):
    """Encode an iterable of texts into padded id and attention-mask tensors.

    Args:
        content: Iterable of strings; ``None`` entries are treated as "".
        max_length: Length every sequence is padded / truncated to.

    Returns:
        A dict with two tensors of the same shape and dtype:
        ``"input_ids"`` and ``"attention_mask"`` (1 where the id differs
        from the tokenizer's pad id, 0 elsewhere).
    """
    truncated_content = []
    for t in content:
        t = t if t is not None else ""
        # NOTE(review): the original encode_plus(...) call was truncated after
        # its first argument; padding/truncation to max_length is reconstructed
        # from the pad-id based mask below -- confirm the intended arguments.
        encoded = tokenizer.encode_plus(
            t,
            max_length=max_length,
            padding='max_length',
            truncation=True,
        )
        truncated_content.append(encoded['input_ids'])

    # NOTE(review): the right-hand side of this assignment was missing in the
    # original; stacking the per-text id lists is the reconstruction.
    input_ids = torch.tensor(truncated_content, dtype=torch.long)
    # Positions holding the pad token are masked out (0), everything else is 1.
    attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.type())
    return {"input_ids": input_ids, "attention_mask": attention_mask}
def _as_dataset(encodings, labels):
    """Bundle token ids, attention masks and labels into one TensorDataset."""
    return TensorDataset(
        encodings['input_ids'],
        encodings['attention_mask'],
        torch.tensor(labels),
    )


# Tokenize each split once, then wrap the resulting tensors.
train_encodings = tokenize(X_train)
val_encodings = tokenize(X_val)
train_dataset = _as_dataset(train_encodings, y_train)
val_dataset = _as_dataset(val_encodings, y_val)

# Training batches are shuffled and the last partial batch is dropped;
# validation keeps the original order and every sample.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, batch_size=32)
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, batch_size=32)
# Focal loss is a loss function that works well for highly imbalanced
# classification problems.
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross-entropy.

    Per sample the loss is ``alpha * (1 - pt) ** gamma * ce`` where ``ce`` is
    the cross-entropy of the sample and ``pt = exp(-ce)``.

    Args:
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
        alpha: Constant scale applied to every sample's loss.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-sample losses).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, gamma=2, alpha=0.3, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits) against ``targets``.

        ``inputs`` and ``targets`` are passed straight to
        ``F.cross_entropy``; the result is reduced according to
        ``self.reduction``.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'sum':
            return focal.sum()
        if self.reduction == 'mean':
            return focal.mean()
        # 'none': hand back the unreduced per-sample losses instead of
        # silently returning None.
        return focal
class SingleInputBert(nn.Module):
    """BERT encoder followed by dropout and a 2-way linear classifier."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('./bert-base-chinese')
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch.

        Args:
            input_ids: Token ids forwarded to the BERT encoder.
            attention_mask: Mask forwarded to the BERT encoder.

        Returns:
            The output of the linear classifier applied to BERT's pooled
            output (last dimension of size 2).
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        # The dropout layer was constructed but never used; apply it before
        # the classifier. It is a no-op in eval mode.
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
# Build the classifier and wrap it for multi-GPU data parallelism.
model = SingleInputBert()
model = nn.DataParallel(model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Cuda: {torch.cuda.is_available()}')

# Move the parameters onto the target device so they live where the batches
# are sent; the original never did this. Done before the optimizer is created
# so the optimizer references the moved parameters.
model.to(device)

optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
criterion = FocalLoss(gamma=2, alpha=1, reduction='mean')  # use FocalLoss as the loss function
# mode='max': the learning rate is reduced when the monitored metric stops increasing.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=2, verbose=True)
def evaluate_model(model, data_loader, device):
    """Evaluate ``model`` on every batch of ``data_loader``.

    The model is switched to eval mode and left in it; callers that continue
    training must call ``model.train()`` again.

    Args:
        model: Module called as ``model(input_ids, attention_mask=...)``.
        data_loader: Iterable yielding ``(input_ids, attention_mask, labels)``.
        device: Device each batch tensor is moved to before the forward pass.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, average_loss)`` where
        ``average_loss`` is the mean of the per-batch ``criterion`` values.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()  # make sure the model is in evaluation mode
    y_true = []
    y_pred = []
    total_loss = 0.0
    total_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            # NOTE(review): the comprehension's element expression was missing
            # in the original; moving each tensor to `device` is reconstructed
            # from the otherwise unused `device` parameter.
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            logits = model(input_ids, attention_mask=attention_mask)
            loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
            total_loss += loss.item()
            total_batches += 1
            preds = torch.argmax(logits, dim=-1)
            # Collect labels and predictions; the original never filled these
            # lists, so the metrics below were computed on empty input.
            y_true.extend(labels.view(-1).cpu().tolist())
            y_pred.extend(preds.view(-1).cpu().tolist())

    if total_batches == 0:
        raise ValueError("data_loader yielded no batches; nothing to evaluate")

    average_loss = total_loss / total_batches
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    return accuracy, precision, recall, f1, average_loss
# NOTE(review): the right-hand side was truncated to `"%Y%m%d")` in the
# original; datetime.now().strftime(...) is reconstructed from the imported
# `datetime` and the format string.
current_date = datetime.now().strftime("%Y%m%d")
save_path = f"./lawer_{current_date}.pt"

# Early-stopping state.
best_precision = 0.0
patience = 10
no_improve = 0

print("Begin Epoch Training...")
for epoch in range(50):  # train for at most 50 epochs
    print(f"Epoch: {epoch+1}")
    # evaluate_model() leaves the model in eval mode, so switch back to
    # training mode at the start of every epoch.
    model.train()
    for batch in train_loader:
        input_ids, attention_mask, labels = [b.to(device) for b in batch]
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask=attention_mask)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        # The original computed the loss but never back-propagated it or
        # stepped the optimizer, so no weights were updated.
        loss.backward()
        optimizer.step()

    train_accuracy, train_precision, train_recall, train_f1, train_loss = evaluate_model(model, train_loader, device)
    val_accuracy, val_precision, val_recall, val_f1, val_loss = evaluate_model(model, val_loader, device)
    print(f"Epoch {epoch + 1} - Training Accuracy: {train_accuracy}, Precision: {train_precision}, Recall: {train_recall}, F1 Score: {train_f1}, Loss: {train_loss}")
    print(f"Epoch {epoch + 1} - Validation Accuracy: {val_accuracy}, Precision: {val_precision}, Recall: {val_recall}, F1 Score: {val_f1}, Loss: {val_loss}")

    # NOTE(review): no scheduler.step call is visible in the original;
    # stepping on validation precision matches mode='max' and the metric used
    # for checkpointing -- confirm.
    scheduler.step(val_precision)

    if val_precision > best_precision:
        best_precision = val_precision
        # NOTE(review): the first argument of the save call was lost in the
        # original (only `, save_path)` survived); saving the state dict is an
        # assumption -- confirm what downstream loading code expects.
        torch.save(model.state_dict(), save_path)
        print(f"Save model to {save_path}, precision: {val_precision}")
        no_improve = 0
    else:
        # Only count epochs without improvement, and actually stop once the
        # patience budget is used up.
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping due to no improvement.")
            break
