@dchaplinsky
Created April 18, 2023 15:43
import os
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
# Load pre-trained model for sentence embeddings
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
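# Optional sanity check: the embedding size produced by this model should match
# input_size below (paraphrase-multilingual-mpnet-base-v2 yields 768-dim vectors)
assert model.get_sentence_embedding_dimension() == 768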
# Set up LSTM model
input_size = 768 # Size of the sentence embeddings
hidden_size = 2048 # Number of hidden units in the LSTM
num_layers = 1 # Number of layers in the LSTM
batch_size = 32 # Batch size for training the LSTM
learning_rate = 0.001 # Initial learning rate for the optimizer
accumulation_steps = 4 # Number of batches to accumulate gradients over
# Define the LSTM model
lstm_model = torch.nn.LSTM(
    input_size=input_size, hidden_size=hidden_size, num_layers=num_layers
)
# Set up the optimizer and loss function for training the LSTM
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=10, verbose=True
)
criterion = torch.nn.CrossEntropyLoss()
# Load data from disk
data_path = "/path/to/data/folder"
filenames = os.listdir(data_path)
# Train the LSTM model
for filename in filenames:
    # Load the file as a list of sentence strings
    with open(os.path.join(data_path, filename), "r", encoding="utf-8") as f:
        sentences = f.read().splitlines()

    # Skip files that are too short to form a single training batch
    if len(sentences) <= batch_size:
        continue

    # Calculate sentence embeddings using the pre-trained model
    embeddings = model.encode(sentences)

    # Convert sentence embeddings to PyTorch tensors
    embeddings_tensor = torch.from_numpy(np.array(embeddings))

    # Reshape to (seq_len, batch=1, input_size), the shape torch.nn.LSTM expects
    # with the default batch_first=False
    embeddings_tensor = embeddings_tensor.view(len(sentences), 1, -1)

    # Train the LSTM model on this file
    total_loss = 0.0
    num_batches = 0
    optimizer.zero_grad()
    for i in range(0, len(sentences) - batch_size, batch_size):
        # Get the batch of sentence embeddings and the corresponding targets
        # (indices of the following sentences, used as class labels; they must
        # stay below hidden_size for CrossEntropyLoss to accept them)
        batch_embeddings = embeddings_tensor[i : i + batch_size]
        targets = torch.LongTensor(range(i + 1, i + batch_size + 1))

        # Forward pass through the LSTM
        outputs, _ = lstm_model(batch_embeddings)

        # Compute the loss, scale it for gradient accumulation and backpropagate
        loss = criterion(outputs.view(batch_size, -1), targets)
        total_loss += loss.item()
        (loss / accumulation_steps).backward()

        # Step the optimizer only every accumulation_steps batches, then reset
        # the accumulated gradients
        num_batches += 1
        if num_batches % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Adjust the learning rate using the scheduler
    scheduler.step(total_loss / num_batches)

    # Log the average per-batch loss for this file
    print(f"File: {filename}, Loss: {total_loss / num_batches}")