Skip to content

Instantly share code, notes, and snippets.

@asmsuechan
Created December 4, 2020 02:52
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save asmsuechan/d820b89f6ed546942b63f0cc4e630e29 to your computer and use it in GitHub Desktop.
Save asmsuechan/d820b89f6ed546942b63f0cc4e630e29 to your computer and use it in GitHub Desktop.
# --- Experiment configuration -------------------------------------------
SEED: int = 2020              # seed shared by every RNG and by the fold splitter
BASE_PATH: str = "./"         # prefix prepended to every CSV file name
TEXT_COL: str = 'description'  # raw column that gets renamed to 'text'
TARGET: str = 'jobflag'        # raw column that gets renamed to 'label'
NUM_CLASS: int = 4            # number of target classes
N_FOLDS: int = 5              # number of stratified folds
import os, gc, sys
import random
import uuid
import pdb
import pandas as pd
import numpy as np
from scipy.special import softmax
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
from sklearn.utils import shuffle
from simpletransformers.classification import ClassificationModel
import torch
from sklearn.model_selection import train_test_split
def metric_f1(labels, preds):
    """Return the macro-averaged F1 score of ``preds`` against ``labels``.

    Args:
        labels: Ground-truth class labels.
        preds: Predicted class labels, aligned with ``labels``.

    Returns:
        The value computed by ``f1_score`` with ``average='macro'``.
    """
    macro_f1 = f1_score(y_true=labels, y_pred=preds, average='macro')
    return macro_f1
def seed_everything(seed: int) -> None:
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: The integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects subprocesses / future interpreters; the hash seed of the
    # running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick a different (non-deterministic) kernel per run.
    torch.backends.cudnn.benchmark = False
# Seed all RNGs before any data is touched.
seed_everything(SEED)

# Test set: expose the text under the name 'text' and drop the 'id' column.
test = pd.read_csv(BASE_PATH + "test.csv")
test = test.rename(columns={TEXT_COL: 'text'}).drop(['id'], axis=1)

# Training set: normalise column names and shift the labels down by one
# (predictions are shifted back by +1 before being written out).
train = pd.read_csv(BASE_PATH + "train-all.csv")
train = train.rename(columns={TARGET: 'label', TEXT_COL: 'text'})
train['label'] -= 1

# Assign every row to one of N_FOLDS stratified folds.
kfold = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
train['fold_id'] = -1
for fold, (train_idx, valid_idx) in enumerate(kfold.split(train.index, train['label'])):
    train.loc[train.iloc[valid_idx].index, 'fold_id'] = fold

# Fold 0 is held out for validation, the remaining folds are used to train.
X_train = train.loc[train['fold_id'] != 0]
X_valid = train.loc[train['fold_id'] == 0]
# Keep text before label explicitly; reversing *all* columns (as was done
# before) does not put these two columns first.
X_valid = X_valid[['text', 'label']]
# X_train=train

# Text columns that contribute training rows. Un-comment a back-translated
# column to add one extra row per training example from that column.
# NOTE: rows are grouped per source column rather than interleaved per
# example; this only matters if more than one source is enabled.
AUGMENTATION_TEXT_COLS = [
    'text',
    # 'fr-en-description',
    # 'de-en-description',
    # 'es-en-description',
    # 'nl-en-description',
    # 'no-en-description',
]

# Built in one vectorised concat instead of one concat per row (O(n^2)).
# Column order stays ('label', 'text').
augmented = pd.concat(
    [
        X_train[['label', source_col]].rename(columns={source_col: 'text'})
        for source_col in AUGMENTATION_TEXT_COLS
    ],
    ignore_index=True,
)
def pseudo_labeling(model, train, test, rate=0.2, threshold=0.95):
    """Train ``model`` on ``train`` and add confidently predicted test rows.

    The model is trained on ``train`` and then used to predict every text in
    ``test``. Predictions whose highest softmax probability is strictly
    greater than ``threshold`` are turned into pseudo-labelled rows and
    concatenated in front of ``train``.

    Args:
        model: Object exposing ``train_model(df)`` and ``predict(texts)``,
            where ``predict`` returns ``(predicted_labels, raw_outputs)``.
        train: DataFrame with at least the columns ``'text'`` and ``'label'``.
        test: DataFrame with a ``'text'`` column; its index may be arbitrary.
        rate: Unused. Kept only so existing callers keep working.
        threshold: Minimum (exclusive) top-class probability for a test
            prediction to be accepted as a pseudo label.

    Returns:
        A tuple ``(augmented_train, percentage)``: the randomly shuffled
        DataFrame of pseudo-labelled rows plus ``train`` with ``'text'`` and
        ``'label'`` as the first two columns, and the fraction of test rows
        that exceeded ``threshold``.
    """
    model.train_model(train)
    #model = ClassificationModel('xlmroberta', 'outputs-e1e8cb02-0f8d-4687-aefc-581d96e70e80/', args={})

    # Hand over a plain list so nothing downstream depends on test's index.
    texts = test['text'].tolist()
    pseudo_labels, raw_outputs = model.predict(texts)
    probabilities = softmax(np.asarray(raw_outputs, dtype=float), axis=1)

    # Boolean mask: did the top-class probability exceed the threshold?
    confident = probabilities.max(axis=1) > threshold
    # Fraction of test rows that exceeded the threshold.
    percentage = float(confident.mean())
    print('The percentage of predictions that the probability is over ' + str(threshold * 100) + '%: ', percentage)

    # Select positionally; the mask is aligned with the order of ``texts``.
    confident_labels = np.asarray(pseudo_labels)[confident].astype(int)
    pl_test = pd.DataFrame({
        'text': [text for text, keep in zip(texts, confident) if keep],
        'label': confident_labels,
    })

    # Number of pseudo-labelled rows per class, one count per model output.
    print(*np.bincount(confident_labels, minlength=probabilities.shape[1]))

    frames = [pl_test, train] if len(pl_test) else [train]
    augmented_train = pd.concat(frames)

    # Put text and label first explicitly instead of reversing the columns,
    # which yields the wrong order when no row was pseudo-labelled.
    leading = ['text', 'label']
    trailing = [col for col in augmented_train.columns if col not in leading]
    augmented_train = augmented_train[leading + trailing]

    # Random row permutation drawn from NumPy's global RNG.
    return augmented_train.sample(frac=1), percentage
# Training frame with text before label. Selected by name rather than by
# reversing the columns, so the result does not depend on the incoming order.
X_train = augmented[['text', 'label']]

# Inverse-frequency class weights, ordered by ascending label value and
# normalised to sum to 1. Kept 1-D (one entry per class present in X_train)
# so that ``weight.tolist()`` is a flat list rather than a list of
# single-element lists.
class_counts = X_train['label'].value_counts().sort_index()
weight = (1.0 / class_counts).to_numpy(dtype=float)
weight_sum = weight.sum()
weight = weight / weight_sum

# "output_dir": "outputs-" + str(uuid.uuid4()) + "/",
# Arguments handed to ClassificationModel via ``args=``.
params = {
    "output_dir": "outputs-1/",
    "max_seq_length": 256,
    "train_batch_size": 16,
    "eval_batch_size": 16,
    "num_train_epochs": 8,
    "learning_rate": 5e-5,
    "manual_seed": SEED,
    'overwrite_output_dir': True,
}
# K-fold validation of a trained model.
# Currently unused.
def k_fold(model, X_train):
    """Return the mean macro-F1 of ``model`` over ``N_FOLDS`` stratified folds.

    For every fold the model is trained on the training part and evaluated
    on the held-out part; the per-fold ``'f1'`` results are averaged.

    NOTE(review): the same ``model`` object is trained in every iteration
    without being re-created, so folds are presumably not independent of
    each other - confirm before relying on this score.

    Args:
        model: Object exposing ``train_model(df)`` and
            ``eval_model(df, f1=...)``.
        X_train: DataFrame with the columns ``'text'`` and ``'label'``.

    Returns:
        The mean of the per-fold ``'f1'`` values.
    """
    splitter = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
    text_col = X_train['text']
    label_col = X_train['label']

    def take(rows):
        # Two-column (text, label) frame restricted to the given positions.
        return pd.concat([text_col.iloc[rows], label_col.iloc[rows]], axis=1)

    fold_scores = []
    split_inputs = (np.array(text_col.to_list()), np.array(label_col.to_list()))
    # The held-out part moves on every iteration, N_FOLDS iterations in total.
    for fit_rows, held_out_rows in splitter.split(*split_inputs):
        fold_train = take(fit_rows)
        fold_valid = take(held_out_rows)
        model.train_model(fold_train)
        result, _, _ = model.eval_model(fold_valid, f1=metric_f1)
        print(result)
        fold_scores.append(result['f1'])

    return np.mean(fold_scores)
# Build the classifier. The class weights are flattened so that a flat list
# with one entry per class is passed, and CUDA is only requested when a
# device is actually available.
model = ClassificationModel(
    'xlmroberta',
    'xlm-roberta-base',
    num_labels=NUM_CLASS,
    weight=np.ravel(weight).tolist(),
    args=params,
    use_cuda=torch.cuda.is_available(),
)

# Repeat pseudo-labelling until at least 98% of the test predictions are
# above the confidence threshold used inside pseudo_labeling.
# NOTE(review): there is no upper bound on the number of rounds, and every
# round feeds the previous round's output (including its pseudo-labelled
# rows) back in as training data - confirm that both are intended.
percentage = 0
while percentage < 0.98:
    X_train, percentage = pseudo_labeling(model, X_train, test)

# Final training pass on the last augmented training set.
model.train_model(X_train)
# result, model_outputs, wrong_predictions = model.eval_model(X_valid, f1=metric_f1)
# print('validation score is ', result)

# Predict the held-out fold. Texts are passed as a plain list because
# X_valid keeps its original (non zero-based) row labels as index.
y_pred_valid, raw_outputs_valid = model.predict(X_valid['text'].tolist())
# Labels were shifted down by one for training; shift them back for output.
valid_df = pd.DataFrame({'id': X_valid.index, 'label': np.array(y_pred_valid) + 1})
valid_df.to_csv(BASE_PATH + "validation_result.csv", index=False)

# Predict the test set and write the submission file.
y_pred, raw_outputs = model.predict(test['text'].tolist())
print(y_pred)
# Re-read the test file to recover the 'id' column that was dropped earlier.
test = pd.read_csv(BASE_PATH + "test.csv")
submit = pd.DataFrame({'index': test['id'], 'pred': np.array(y_pred) + 1})
submit.to_csv(BASE_PATH + "submit_model2_bert-2.csv", index=False, header=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment