|
diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index e44e573..ae4dea4 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -16,6 +16,7 @@ from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 import torch.multiprocessing as mp
 from tqdm.autonotebook import trange
+from tqdm.autonotebook import tqdm
 import math
 import queue
 import tempfile
@@ -272,7 +273,7 @@ class SentenceTransformer(nn.Sequential):
         last_chunk_id = 0
         chunk = []
 
-        for sentence in sentences:
+        for sentence in tqdm(sentences, desc="Enqueue sentences"):
             chunk.append(sentence)
             if len(chunk) >= chunk_size:
                 input_queue.put([last_chunk_id, batch_size, chunk])
@@ -284,7 +285,7 @@ class SentenceTransformer(nn.Sequential):
             last_chunk_id += 1
 
         output_queue = pool['output']
-        results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0])
+        results_list = sorted([output_queue.get() for _ in tqdm(range(last_chunk_id), desc="Dequeue sentences")], key=lambda x: x[0])
         embeddings = np.concatenate([result[1] for result in results_list])
         return embeddings
 
@@ -591,7 +592,8 @@ class SentenceTransformer(nn.Sequential):
             show_progress_bar: bool = True,
             checkpoint_path: str = None,
             checkpoint_save_steps: int = 500,
-            checkpoint_save_total_limit: int = 0
+            checkpoint_save_total_limit: int = 0,
+            accelerator=None
             ):
         """
         Train the model with the given training objective
@@ -637,8 +639,9 @@ class SentenceTransformer(nn.Sequential):
         if use_amp:
             from torch.cuda.amp import autocast
             scaler = torch.cuda.amp.GradScaler()
-
-        self.to(self._target_device)
+
+        if accelerator is None:
+            self.to(self._target_device)
 
         dataloaders = [dataloader for dataloader, _ in train_objectives]
 
@@ -647,8 +650,9 @@ class SentenceTransformer(nn.Sequential):
             dataloader.collate_fn = self.smart_batching_collate
 
         loss_models = [loss for _, loss in train_objectives]
-        for loss_model in loss_models:
-            loss_model.to(self._target_device)
+        if accelerator is None:
+            for loss_model in loss_models:
+                loss_model.to(self._target_device)
 
         self.best_score = -9999999
 
@@ -682,6 +686,36 @@ class SentenceTransformer(nn.Sequential):
         num_train_objectives = len(train_objectives)
 
         skip_scheduler = False
+
+        if accelerator:
+            def _accelerate_items(loss_models, optimizers, schedulers, dataloaders):
+                n_loss_models = len(loss_models)
+                n_optimizers = len(optimizers)
+                n_schedulers = len(schedulers)
+                n_dataloaders = len(dataloaders)
+                acceleratable_items = loss_models + optimizers + schedulers + dataloaders
+                logger.info(f"Before: {len(loss_models)=}, {len(optimizers)=}, {len(schedulers)=}, {len(dataloaders)=}, {len(acceleratable_items)=}")
+                acceleratable_items = accelerator.prepare(*acceleratable_items)
+                logger.info(f"Accelerated {len(acceleratable_items)=} items")
+                i = 0
+                loss_models = acceleratable_items[i:i+n_loss_models]
+                i += n_loss_models
+                optimizers = acceleratable_items[i:i+n_optimizers]
+                i += n_optimizers
+                schedulers = acceleratable_items[i:i+n_schedulers]
+                i += n_schedulers
+                dataloaders = acceleratable_items[i:i+n_dataloaders]
+                i += n_dataloaders
+                logger.info(f"After: {len(loss_models)=}, {len(optimizers)=}, {len(schedulers)=}, {len(dataloaders)=}, {len(acceleratable_items)=}")
+                return accelerator, loss_models, optimizers, schedulers, dataloaders
+            accelerator, loss_models, optimizers, schedulers, dataloaders = _accelerate_items(
+                loss_models, optimizers, schedulers, dataloaders
+            )
+            show_progress_bar = show_progress_bar and accelerator.is_local_main_process
+
+
+
+
         for epoch in trange(epochs, desc="Epoch", disable=not show_progress_bar):
             training_steps = 0
 
@@ -704,8 +738,9 @@ class SentenceTransformer(nn.Sequential):
                         data = next(data_iterator)
 
                     features, labels = data
-                    labels = labels.to(self._target_device)
-                    features = list(map(lambda batch: batch_to_device(batch, self._target_device), features))
+                    if accelerator is None:
+                        labels = labels.to(self._target_device)
+                        features = list(map(lambda batch: batch_to_device(batch, self._target_device), features))
 
                     if use_amp:
                         with autocast():
@@ -721,7 +756,10 @@ class SentenceTransformer(nn.Sequential):
                         skip_scheduler = scaler.get_scale() != scale_before_step
                     else:
                         loss_value = loss_model(features, labels)
-                        loss_value.backward()
+                        if accelerator:
+                            accelerator.backward(loss_value)
+                        else:
+                            loss_value.backward()
                         torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
                         optimizer.step()
 
@@ -741,16 +779,19 @@ class SentenceTransformer(nn.Sequential):
                         loss_model.train()
 
                 if checkpoint_path is not None and checkpoint_save_steps is not None and checkpoint_save_steps > 0 and global_step % checkpoint_save_steps == 0:
-                    self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)
+                    if accelerator is None or accelerator.is_local_main_process:
+                        self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)
 
 
             self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback)
 
         if evaluator is None and output_path is not None:   #No evaluator, but output path: save final model version
-            self.save(output_path)
+            if accelerator is None or accelerator.is_local_main_process:
+                self.save(output_path)
 
         if checkpoint_path is not None:
-            self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)
+            if accelerator is None or accelerator.is_local_main_process:
+                self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)
 
 
 
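For context, a minimal usage sketch of the new `accelerator` argument, assuming the `accelerate` package is installed and the script is started with `accelerate launch` (or `accelerate config` has been run); the model name, training examples, and loss below are illustrative placeholders, not part of this patch:

from accelerate import Accelerator
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# Accelerator() picks up the distributed setup created by `accelerate launch`
accelerator = Accelerator()

model = SentenceTransformer("distilbert-base-nli-mean-tokens")
train_examples = [
    InputExample(texts=["Anchor 1", "Positive 1"]),
    InputExample(texts=["Anchor 2", "Positive 2"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.MultipleNegativesRankingLoss(model)

# With accelerator set, fit() skips its own .to(device) calls and instead uses
# accelerator.prepare() and accelerator.backward(); with accelerator=None the
# original single-device code path is unchanged.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    warmup_steps=10,
    accelerator=accelerator,
)

With Accelerate in the loop, checkpoints and the final model are written only from the local main process, matching the `accelerator.is_local_main_process` guards added above.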