Created
July 3, 2024 03:22
-
-
Save classicvalues/972b8e922100d19c2cec39309b8271bc to your computer and use it in GitHub Desktop.
A Trained and Fine-Tuned LLM via Publicly Available Literature
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import fitz # PyMuPDF | |
import re | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from transformers import AdamW | |
# Pick the GPU when one is available; model and batches are moved here later.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def extract_text_from_pdf(pdf_path):
    """Return the plain text of every page of the PDF at *pdf_path*, concatenated.

    Uses PyMuPDF (``fitz``); the ``with`` block guarantees the document handle
    is closed even if text extraction raises.
    """
    with fitz.open(pdf_path) as doc:
        # join at C speed instead of the quadratic `text += ...` loop
        return "".join(page.get_text() for page in doc)
def clean_text(text):
    """Normalize extracted PDF text.

    Collapses every whitespace run to a single space, removes numeric
    citation markers such as ``[12]``, and strips the ends.
    Order matters: whitespace is collapsed *before* references are removed,
    matching the original pipeline's output exactly.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    without_refs = re.sub(r'\[\d+\]', '', collapsed)
    return without_refs.strip()
def tokenize_texts(texts, tokenizer):
    """Tokenize each text separately into PyTorch tensors.

    Returns one tokenizer output per input text; each call truncates to the
    tokenizer's model maximum length. Note that padding is applied per single
    text, so different texts may still have different sequence lengths.
    """
    encoded = []
    for text in texts:
        encoded.append(
            tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        )
    return encoded
def train(model, dataloader, optimizer, epochs=1):
    """Run a basic causal-LM training loop.

    Each batch is an ``(input_ids, attention_mask)`` pair; the inputs double
    as labels (standard next-token objective, the model shifts internally).

    Args:
        model: a HuggingFace causal LM already moved to ``device``.
        dataloader: yields ``(input_ids, attention_mask)`` tensor pairs.
        optimizer: optimizer over ``model.parameters()``.
        epochs: number of passes over the dataloader.

    Fix vs. original: reports the *mean* loss per epoch rather than whatever
    the final batch's loss happened to be, and does not divide by zero on an
    empty dataloader.
    """
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for inputs, masks in dataloader:
            inputs = inputs.to(device)
            masks = masks.to(device)
            outputs = model(input_ids=inputs, attention_mask=masks, labels=inputs)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch}, Loss: {mean_loss}")
class TextDataset(Dataset): | |
def __init__(self, tokenized_texts): | |
self.tokenized_texts = tokenized_texts | |
def __len__(self): | |
return len(self.tokenized_texts) | |
def __getitem__(self, idx): | |
return self.tokenized_texts[idx]["input_ids"].squeeze(), self.tokenized_texts[idx]["attention_mask"].squeeze() | |
# Example usage: pre-train on public-domain book PDFs, then nudge the model
# toward classic narrative structure, then sample a book summary.
# NOTE(review): "Gone by the Wind ..." looks like a typo for "Gone with the
# Wind" -- confirm it matches the actual file name on disk.
pdf_paths = ["A Tale of Two Cities by Charles Dickens.pdf", "And Then There Were None by Agatha Christie.pdf", "Don Quixote by Miguel de Cervantes.pdf", "Gone by the Wind by Margaret Mitchell.pdf", "Harry Potter and the Philosopher's Stone by J.K. Rowling (Excerpt).pdf", "One Hundred Years of Solitude by Gabriel Garcia Marquez.pdf", "The Alchemist by Paulo Coelho.pdf", "The Catcher in the Rye by J.D. Salinger.pdf", "The Da Vinci Code by Dan Brown.pdf", "The Hobbit by J.R.R. Tolkien.pdf", "The Little Prince by Antoine de Saint-Exupery.pdf", "The Lord of the Rings by J.R.R. Tolkien.pdf", "Twilight by Stephanie Meyer.pdf", "War and Peace by Leo Tolstoy.pdf"]
texts = [extract_text_from_pdf(pdf) for pdf in pdf_paths]
cleaned_texts = [clean_text(text) for text in texts]
# NOTE(review): google/gemma-2-27b is a gated, ~27B-parameter checkpoint;
# full fine-tuning with plain AdamW at this scale needs far more GPU memory
# than a single consumer device -- confirm the intended hardware.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-27b")
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)
# Each text is tokenized independently, so sequence lengths differ per item.
# NOTE(review): with batch_size=2 the DataLoader's default collate must stack
# two variable-length tensors and will likely raise -- a padding collate_fn
# (or batch_size=1) is presumably needed; verify before running.
tokenized_texts = tokenize_texts(cleaned_texts, tokenizer)
dataset = TextDataset(tokenized_texts)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
train(model, dataloader, optimizer, epochs=1)
# Second, tiny "fine-tuning" pass on a single instruction-like sentence.
fine_tuning_texts = ["The text should hold the specific literary structures of exposition, inciting incident, rising action, climax, falling action, resolution, and epilogue"]
tokenized_fine_tuning_texts = tokenize_texts(fine_tuning_texts, tokenizer)
fine_tuning_dataset = TextDataset(tokenized_fine_tuning_texts)
fine_tuning_dataloader = DataLoader(fine_tuning_dataset, batch_size=2, shuffle=True)
train(model, fine_tuning_dataloader, optimizer, epochs=1)
# Generate from the fine-tuned model.
prompt = 'Please write a Literary Fictitious adventure Book Summary under the name "The Noble Knight" that includes details about Narrative styles, dialogue structures, character development, plot construction.'
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to(device)
# NOTE(review): generate() is called without attention_mask; HF warns about
# this when pad and eos tokens coincide -- consider passing inputs.attention_mask.
outputs = model.generate(inputs.input_ids, max_length=1255, num_return_sequences=1)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment