@napsternxg
Last active November 7, 2023 16:20
accelerate support for sentence_transformers: a patch against SentenceTransformer.py that adds an optional accelerator argument to SentenceTransformer.fit() for distributed / multi-GPU training with Hugging Face Accelerate, plus progress bars for encode_multi_process.
diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index e44e573..ae4dea4 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -16,6 +16,7 @@ from torch.optim import Optimizer
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
from tqdm.autonotebook import trange
+from tqdm.autonotebook import tqdm
import math
import queue
import tempfile
@@ -272,7 +273,7 @@ class SentenceTransformer(nn.Sequential):
last_chunk_id = 0
chunk = []
- for sentence in sentences:
+ for sentence in tqdm(sentences, desc="Enqueue sentences"):
chunk.append(sentence)
if len(chunk) >= chunk_size:
input_queue.put([last_chunk_id, batch_size, chunk])
@@ -284,7 +285,7 @@ class SentenceTransformer(nn.Sequential):
last_chunk_id += 1
output_queue = pool['output']
- results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0])
+ results_list = sorted([output_queue.get() for _ in tqdm(range(last_chunk_id), desc="Dequeue sentences")], key=lambda x: x[0])
embeddings = np.concatenate([result[1] for result in results_list])
return embeddings
@@ -591,7 +592,8 @@ class SentenceTransformer(nn.Sequential):
show_progress_bar: bool = True,
checkpoint_path: str = None,
checkpoint_save_steps: int = 500,
- checkpoint_save_total_limit: int = 0
+ checkpoint_save_total_limit: int = 0,
+ accelerator=None
):
"""
Train the model with the given training objective
@@ -637,8 +639,9 @@ class SentenceTransformer(nn.Sequential):
if use_amp:
from torch.cuda.amp import autocast
scaler = torch.cuda.amp.GradScaler()
-
- self.to(self._target_device)
+
+ if accelerator is None:
+ self.to(self._target_device)
dataloaders = [dataloader for dataloader, _ in train_objectives]
@@ -647,8 +650,9 @@ class SentenceTransformer(nn.Sequential):
dataloader.collate_fn = self.smart_batching_collate
loss_models = [loss for _, loss in train_objectives]
- for loss_model in loss_models:
- loss_model.to(self._target_device)
+ if accelerator is None:
+ for loss_model in loss_models:
+ loss_model.to(self._target_device)
self.best_score = -9999999
@@ -682,6 +686,36 @@ class SentenceTransformer(nn.Sequential):
num_train_objectives = len(train_objectives)
skip_scheduler = False
+
+ if accelerator:
+ def _accelerate_items(loss_models, optimizers, schedulers, dataloaders):
+ n_loss_models = len(loss_models)
+ n_optimizers = len(optimizers)
+ n_schedulers = len(schedulers)
+ n_dataloaders = len(dataloaders)
+ acceleratable_items = loss_models + optimizers + schedulers + dataloaders
+ logger.info(f"Before: {len(loss_models)=}, {len(optimizers)=}, {len(schedulers)=}, {len(dataloaders)=}, {len(acceleratable_items)=}")
+ acceleratable_items = accelerator.prepare(*acceleratable_items)
+ logger.info(f"Accelerated {len(acceleratable_items)=} items")
+ i = 0
+ loss_models = acceleratable_items[i:i+n_loss_models]
+ i += n_loss_models
+ optimizers = acceleratable_items[i:i+n_optimizers]
+ i += n_optimizers
+ schedulers = acceleratable_items[i:i+n_schedulers]
+ i += n_schedulers
+ dataloaders = acceleratable_items[i:i+n_dataloaders]
+ i += n_dataloaders
+ logger.info(f"After: {len(loss_models)=}, {len(optimizers)=}, {len(schedulers)=}, {len(dataloaders)=}, {len(acceleratable_items)=}")
+ return accelerator, loss_models, optimizers, schedulers, dataloaders
+ accelerator, loss_models, optimizers, schedulers, dataloaders = _accelerate_items(
+ loss_models, optimizers, schedulers, dataloaders
+ )
+ show_progress_bar = show_progress_bar and accelerator.is_local_main_process
+
+
+
+
for epoch in trange(epochs, desc="Epoch", disable=not show_progress_bar):
training_steps = 0
@@ -704,8 +738,9 @@ class SentenceTransformer(nn.Sequential):
data = next(data_iterator)
features, labels = data
- labels = labels.to(self._target_device)
- features = list(map(lambda batch: batch_to_device(batch, self._target_device), features))
+ if accelerator is None:
+ labels = labels.to(self._target_device)
+ features = list(map(lambda batch: batch_to_device(batch, self._target_device), features))
if use_amp:
with autocast():
@@ -721,7 +756,10 @@ class SentenceTransformer(nn.Sequential):
skip_scheduler = scaler.get_scale() != scale_before_step
else:
loss_value = loss_model(features, labels)
- loss_value.backward()
+ if accelerator:
+ accelerator.backward(loss_value)
+ else:
+ loss_value.backward()
torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
optimizer.step()
@@ -741,16 +779,19 @@ class SentenceTransformer(nn.Sequential):
loss_model.train()
if checkpoint_path is not None and checkpoint_save_steps is not None and checkpoint_save_steps > 0 and global_step % checkpoint_save_steps == 0:
- self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)
+ if accelerator is None or accelerator.is_local_main_process:
+ self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)
self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback)
if evaluator is None and output_path is not None: #No evaluator, but output path: save final model version
- self.save(output_path)
+ if accelerator is None or accelerator.is_local_main_process:
+ self.save(output_path)
if checkpoint_path is not None:
- self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)
+ if accelerator is None or accelerator.is_local_main_process:
+ self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)
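
The first two hunks above only add progress bars to encode_multi_process, which otherwise gives no feedback while sentence chunks are enqueued to the worker processes and the results are collected. For context, a typical multi-process encoding call looks like the sketch below; the model name and sentences are placeholders, not part of the patch.

from sentence_transformers import SentenceTransformer

if __name__ == "__main__":  # guard needed because the pool spawns worker processes
    # Placeholder model and data; with the patch applied, "Enqueue sentences" and
    # "Dequeue sentences" progress bars are shown while this call runs.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    sentences = [f"sentence number {i}" for i in range(100_000)]

    pool = model.start_multi_process_pool()  # one worker process per available device
    embeddings = model.encode_multi_process(sentences, pool, batch_size=32)
    model.stop_multi_process_pool(pool)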

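The rest of the patch adds an optional accelerator argument to fit(). The _accelerate_items helper relies on accelerator.prepare accepting an arbitrary number of objects and returning them, wrapped, in the same order, so it only needs the group sizes to slice the flat result back into loss models, optimizers, schedulers and dataloaders. Below is a minimal, self-contained sketch of that pattern; the toy modules, optimizer, scheduler and data are illustrative only.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# Toy stand-ins for the per-objective objects that fit() collects.
loss_models = [nn.Linear(4, 1), nn.Linear(4, 1)]
optimizers = [torch.optim.AdamW(m.parameters(), lr=1e-3) for m in loss_models]
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1.0) for opt in optimizers]
dataloaders = [DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2) for _ in loss_models]

# prepare(*items) wraps every object (device placement, DDP, sharded dataloaders, ...)
# and returns them in the same order they were passed in.
prepared = accelerator.prepare(*(loss_models + optimizers + schedulers + dataloaders))

# Slice the flat tuple back into the original groups, as _accelerate_items does.
n = len(loss_models)
loss_models = list(prepared[0:n])
optimizers = list(prepared[n:2 * n])
schedulers = list(prepared[2 * n:3 * n])
dataloaders = list(prepared[3 * n:4 * n])
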
Usage is as follows:

from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs
from sentence_transformers import SentenceTransformer, losses

# find_unused_parameters lets DDP cope with parameters that do not receive
# gradients on every step.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = SentenceTransformer(model_type)  # model_type: name or path of the base model
train_loss = losses.CosineSimilarityLoss(model)
model.fit(
    train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100,
    show_progress_bar=True,
    checkpoint_path=checkpoint_path,
    checkpoint_save_steps=50_000,
    checkpoint_save_total_limit=1,
    accelerator=accelerator
)
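
The snippet assumes model_type, checkpoint_path and train_dataloader are defined elsewhere. For CosineSimilarityLoss, train_dataloader is typically built from InputExample pairs carrying a float similarity label; a minimal sketch with made-up sentence pairs and scores:

from torch.utils.data import DataLoader
from sentence_transformers import InputExample

# Hypothetical pairs with similarity scores in [0, 1], as expected by CosineSimilarityLoss.
train_examples = [
    InputExample(texts=["A plane is taking off.", "An air plane is taking off."], label=0.95),
    InputExample(texts=["A man is playing a flute.", "A man is eating a banana."], label=0.05),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

To train on more than one GPU, start the training script through the Accelerate launcher (for example, accelerate launch train.py after running accelerate config, where train.py stands for your own script), so that one process is created per device and the accelerator.is_local_main_process checks in the patch keep only the main process writing checkpoints and the final model.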