@littlewine
Created November 8, 2023 09:29
import logging

import pandas as pd
from simpletransformers.retrieval import RetrievalModel, RetrievalArgs

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
# Paths: NQ training pairs (TSV) and a BEIR-format eval dataset (SciFact).
train_data_path = "../data/nq-train.tsv"
eval_data_path = "../data_beir/data_scifact"

# Read a .tsv into a DataFrame; any other value is passed through
# as-is and left for the model to load.
if train_data_path.endswith(".tsv"):
    train_data = pd.read_csv(train_data_path, sep="\t")
else:
    train_data = train_data_path
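# Illustrative only: my assumption (based on the simpletransformers
# retrieval docs) of the columns the training frame needs; one row per
# (query, positive passage) pair, with "title" used when include_title=True.
#
#   example_train = pd.DataFrame(
#       {
#           "query_text": ["who wrote hamlet"],
#           "title": ["Hamlet"],
#           "gold_passage": ["Hamlet is a tragedy written by William Shakespeare."],
#       }
#   )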
model_args = RetrievalArgs()

# Data handling
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.use_cached_eval_features = False
# MS MARCO passages carry no titles; other corpora do.
model_args.include_title = "msmarco" not in train_data_path
model_args.max_seq_length = 256
model_args.use_hf_datasets = True
model_args.data_format = "beir"

# Training hyperparameters
model_args.num_train_epochs = 1
model_args.train_batch_size = 16
model_args.eval_batch_size = 300
model_args.learning_rate = 1e-6
model_args.warmup_steps = 5000
model_args.hard_negatives = False
model_args.hard_negatives_in_eval = False
model_args.n_gpu = 1

# Checkpointing and in-training evaluation
model_args.save_steps = -1
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 1000
model_args.evaluate_during_training_verbose = True
model_args.evaluate_each_epoch = False
model_args.evaluate_with_beir = False
model_args.save_model_every_epoch = False
model_args.save_eval_checkpoints = False
model_args.save_best_model = True
model_args.early_stopping_metric = "recip_rank"
model_args.early_stopping_metric_minimize = False

# Logging and output locations
model_args.wandb_project = "IR2 Demo"
model_args.wandb_kwargs = {
    "name": f"repro-dpr-epochs-{model_args.num_train_epochs}-batch_size-{model_args.train_batch_size}"
}
model_args.output_dir = (
    f"../models/dpr-epochs-{model_args.num_train_epochs}-batch_size-{model_args.train_batch_size}"
)
model_args.best_model_dir = model_args.output_dir + "/best_model"
# "custom" builds a bi-encoder from two separate towers: a context
# encoder and a question encoder, both initialised from bert-base-cased.
model_type = "custom"
model_name = None
context_name = "bert-base-cased"
question_name = "bert-base-cased"
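# A known alternative from the simpletransformers docs: start from the
# pretrained DPR single-NQ encoders instead of raw BERT towers.
#
#   model_type = "dpr"
#   context_name = "facebook/dpr-ctx_encoder-single-nq-base"
#   question_name = "facebook/dpr-question_encoder-single-nq-base"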
if __name__ == "__main__":
    # "spawn" is needed so CUDA initialises safely in worker processes.
    from multiprocess import set_start_method

    set_start_method("spawn")

    # Create the bi-encoder retrieval model
    model = RetrievalModel(
        model_type,
        model_name,
        context_name,
        question_name,
        args=model_args,
    )

    model.train_model(
        train_data,
        # clustered_training=True,
        eval_data=eval_data_path,
        eval_set="test",
    )
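    # Follow-up sketch, not part of the original gist: one could score the
    # trained encoder on the BEIR test split after training, along the
    # lines of
    #
    #   results = model.eval_model(eval_data_path, eval_set="test", verbose=True)
    #   print(results)
    #
    # (eval_set here mirrors the train_model call above; whether eval_model
    # accepts it may depend on the simpletransformers version.)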