Skip to content

Instantly share code, notes, and snippets.

@A-baoYang
Created May 3, 2022 05:45
Show Gist options
  • Save A-baoYang/8b17a041662ed3eff5a94e0f15b669aa to your computer and use it in GitHub Desktop.
Save A-baoYang/8b17a041662ed3eff5a94e0f15b669aa to your computer and use it in GitHub Desktop.
if __name__ == "__main__":
    # Fine-tune a sequence classifier on the iFLYTEK long-text dataset.
    # NOTE(review): indentation below was lost in the original paste
    # (SyntaxError as published); reconstructed here.

    # --- Load data -----------------------------------------------------
    np.random.seed(112)  # fix RNG for reproducibility
    PRETRAINED_MODEL_NAME = "longformerModel_full"
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

    df_train = read_data(path="../data/iflytek_public/train.json")
    df_val = read_data(path="../data/iflytek_public/dev.json")
    df_test = read_data(path="../data/iflytek_public/test.json")
    print(df_train.shape, df_val.shape, df_test.shape)

    # Record each sentence's character length so the model input size
    # can cover the longest example in any split.
    for df in (df_train, df_val, df_test):
        df["sentence_length"] = df["sentence"].apply(len)

    max_length = max(
        df_train.sentence_length.max(),
        df_val.sentence_length.max(),
        df_test.sentence_length.max(),
    )
    # Round up to the next multiple of 512 (must be a multiple of 512).
    max_length = 512 * (max_length // 512 + 1)
    print(max_length)

    # Build label-description -> integer-id mapping.
    label_dict = read_data(path="../data/iflytek_public/labels.json")
    label_dict = {
        item["label_des"]: int(item["label"])
        for item in label_dict.to_dict("records")
    }

    trainset = NewsDataset(mode="train", tokenizer=tokenizer,
                           max_length=max_length, df=df_train,
                           label_dict=label_dict)
    testset = NewsDataset(mode="test", tokenizer=tokenizer,
                          max_length=max_length, df=df_test,
                          label_dict=label_dict)
    print(len(trainset), len(testset))

    # --- Train ---------------------------------------------------------
    BATCH_SIZE = 4
    # NOTE(review): EPOCHS and LR are never read in this block —
    # presumably consumed by model_train via globals; confirm.
    EPOCHS = 6
    LR = 1e-6
    train_sampler = RandomSampler(trainset)
    trainloader = DataLoader(trainset, sampler=train_sampler,
                             batch_size=BATCH_SIZE)

    NUM_LABELS = len(label_dict)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BertForSequenceClassification.from_pretrained(
        PRETRAINED_MODEL_NAME, num_labels=NUM_LABELS
    )
    # Wrap in DataParallel only when two GPUs are actually present;
    # the original unconditional device_ids=[0, 1] crashes on
    # single-GPU or CPU-only hosts.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model, device_ids=[0, 1])
    model = model.to(device)
    model_train(model, trainloader)
@18410080631
Copy link

请问这个tokenizer是在哪儿找的,代码中的tokenizer报错,找不到longformerModel_full,longformer_zh中也没有给出相应的tokenizer

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment