Train a classification task with BERT; see the full introduction at https://selfboot.cn/2023/12/06/bert_nlp_classify/
#!pip install torch
#!pip install transformers
#!pip install scikit-learn
#!pip install numpy
import json
from sklearn.model_selection import train_test_split
import random
from datetime import datetime
from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW
from sklearn.metrics import f1_score, recall_score, accuracy_score, precision_score
import torch.nn.functional as F
import sys
# Open a file to record the training log
log_file = open('lawer_train_output.log', 'w', buffering=1)
sys.stdout = log_file
sys.stderr = log_file
print("begin")
# Read the positive-class ("lawer") samples
lawer_data = []
with open('./train_lawer.json', 'r') as f:
    for line in f:
        lawer_data.append(json.loads(line))
# Read the negative-class (non-"lawer") samples
not_lawer_data = []
with open('./train_notlawer.json', 'r') as f:
    for line in f:
        not_lawer_data.append(json.loads(line))
lawer_data = [(str(d['content']), 1) for d in lawer_data]
not_lawer_data = [(str(d['content']), 0) for d in not_lawer_data]
# Merge the positive and negative samples, then shuffle
merged_data = lawer_data + not_lawer_data
random.shuffle(merged_data)
# Separate features and labels
X, y = zip(*merged_data)
# Split the dataset: 80% for training, 20% for validation
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')
def tokenize(content, max_length=512):
    truncated_content = []
    for t in content:
        t = t if t is not None else ""
        encoded = tokenizer.encode_plus(t,
                                        max_length=max_length,
                                        padding='max_length',
                                        truncation=True,
                                        return_tensors="pt")
        truncated_content.append(encoded['input_ids'])
    input_ids = torch.cat(truncated_content)
    attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.type())
    return {"input_ids": input_ids, "attention_mask": attention_mask}
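# A quick, optional sanity check (illustrative only; the sample sentence below is made up):
# tokenize a single text and confirm both tensors have shape (1, 512) with the default
# max_length, and that the attention mask is 0 on the padded positions.
# sample = tokenize(["这是一条测试文本"])
# print(sample["input_ids"].shape, sample["attention_mask"].shape)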
train_encodings = tokenize(X_train)
val_encodings = tokenize(X_val)
train_dataset = TensorDataset(train_encodings['input_ids'], train_encodings['attention_mask'], torch.tensor(y_train))
val_dataset = TensorDataset(val_encodings['input_ids'], val_encodings['attention_mask'], torch.tensor(y_val))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=False)
# Focal Loss is an effective loss function for highly imbalanced classification problems.
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=0.3, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Per-sample cross entropy, then down-weight easy examples by (1 - pt) ** gamma
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'sum':
            return focal_loss.sum()
        elif self.reduction == 'mean':
            return focal_loss.mean()
        return focal_loss  # reduction='none': return per-sample losses
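# For reference, focal loss rescales the per-example cross entropy by (1 - pt) ** gamma,
# where pt is the model's probability for the true class:
#     FL(pt) = -alpha * (1 - pt) ** gamma * log(pt)
# With gamma=2, an easy example (pt = 0.9) is down-weighted by (1 - 0.9) ** 2 = 0.01,
# while a hard example (pt = 0.2) keeps a weight of (1 - 0.2) ** 2 = 0.64, so training
# focuses on the hard, typically minority-class, examples.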
class SingleInputBert(nn.Module):
    def __init__(self):
        super(SingleInputBert, self).__init__()
        self.bert = BertModel.from_pretrained('./bert-base-chinese')
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        logits = self.classifier(pooled_output)
        return logits
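# The head is deliberately small: BERT's pooler_output (the [CLS] representation passed
# through BERT's own dense + tanh pooling layer) goes through dropout and a single
# linear layer that produces 2 logits, one per class.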
model = SingleInputBert()
model = nn.DataParallel(model)
optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Cuda: {torch.cuda.is_available()}')
model.to(device)
criterion = FocalLoss(gamma=2, alpha=1, reduction='mean')  # use FocalLoss as the loss function; alpha=1 applies no extra class weighting
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)  # stepped on validation loss below, so lower is better (mode='min')
def evaluate_model(model, data_loader, device):
    model.eval()  # make sure the model is in evaluation mode
    y_true = []
    y_pred = []
    total_loss = 0.0
    total_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            logits = model(input_ids, attention_mask=attention_mask)
            loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
            total_loss += loss.item()
            total_batches += 1
            preds = torch.argmax(logits, dim=-1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    average_loss = total_loss / total_batches
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    return accuracy, precision, recall, f1, average_loss
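# Note: precision_score, recall_score and f1_score default to binary averaging here,
# with label 1 (the "lawer" class) treated as the positive class.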
current_date = datetime.now().strftime("%Y%m%d")
save_path = f"./lawer_{current_date}.pt"
# Early-stopping parameters
best_precision = 0.0
patience = 10
no_improve = 0
print("Begin Epoch Training...")
for epoch in range(50):  # train for at most 50 epochs
    print(f"Epoch: {epoch+1}")
    model.train()  # switch back to training mode after the previous epoch's evaluation
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids, attention_mask, labels = [b.to(device) for b in batch]
        logits = model(input_ids, attention_mask=attention_mask)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss.backward()
        optimizer.step()
    train_accuracy, train_precision, train_recall, train_f1, train_loss = evaluate_model(model, train_loader, device)
    val_accuracy, val_precision, val_recall, val_f1, val_loss = evaluate_model(model, val_loader, device)
    scheduler.step(val_loss)
    print(f"Epoch {epoch + 1} - Training Accuracy: {train_accuracy}, Precision: {train_precision}, Recall: {train_recall}, F1 Score: {train_f1}, Loss: {train_loss}")
    print(f"Epoch {epoch + 1} - Validation Accuracy: {val_accuracy}, Precision: {val_precision}, Recall: {val_recall}, F1 Score: {val_f1}, Loss: {val_loss}")
    if val_precision > best_precision:
        best_precision = val_precision
        torch.save(model.state_dict(), save_path)
        print(f"Save model to {save_path}, precision: {val_precision}")
        no_improve = 0
    else:
        no_improve += 1
    if no_improve >= patience:
        print("Early stopping due to no improvement.")
        break
# Restore stdout/stderr before closing the log file
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
log_file.close()
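# A minimal inference sketch, not part of the original training run: it assumes the
# checkpoint saved at save_path exists and reuses tokenize() defined above. The sample
# sentence is hypothetical; replace it with real input. Uncomment to try it out.
# best_model = nn.DataParallel(SingleInputBert())
# best_model.load_state_dict(torch.load(save_path, map_location=device))
# best_model.to(device)
# best_model.eval()
# sample = tokenize(["这是一条待分类的文本"])
# with torch.no_grad():
#     logits = best_model(sample["input_ids"].to(device),
#                         attention_mask=sample["attention_mask"].to(device))
# print("Predicted label:", torch.argmax(logits, dim=-1).item())  # 1 = lawer, 0 = not lawer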