solving the alignment problem one step at a time
# Scrape LessWrong posts and fine-tune facebook/opt-2.7b on them.
import os
import json
import time
import random
import requests
from bs4 import BeautifulSoup
from urllib.parse import urljoin
from colorama import init, Fore, Style
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
from concurrent.futures import ThreadPoolExecutor
import threading

init(autoreset=True)

# Lock guarding concurrent appends/writes to the shared posts list.
posts_lock = threading.Lock()
def load_json(file_path, default=None):
    """Load a JSON file, returning `default` (or a fresh list) if it is missing or corrupt."""
    if os.path.exists(file_path):
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return json.load(file)
        except json.JSONDecodeError:
            pass
    return default if default is not None else []


def save_json(data, file_path):
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file)
def scrape_site_post(url, seen_links, posts, unseen_links):
    """Fetch a single LessWrong page, extract the post text, and queue new post links."""
    headers = {'User-Agent': 'Mozilla/5.0'}
    post_body_div_class = 'PostsPage-postContent'
    post_title_div_class = 'PostsPageTitle-link'
    valid_prefixes = ["https://www.lesswrong.com/posts/",
                      "https://www.lesswrong.com/s"]
    print(f"{Fore.CYAN}Processing: {url}")
    try:
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code == 200:
            soup = BeautifulSoup(response.text, 'html.parser')
            title_div = soup.find('a', class_=post_title_div_class)
            body_div = soup.find('div', class_=post_body_div_class)
            if title_div and body_div:
                # Title followed by the post body, paragraph by paragraph.
                post = title_div.get_text(strip=True) + "\n\n"
                post += '\n\n'.join(p.get_text() for p in body_div.find_all('p'))
                print(f"{Fore.GREEN}Post Preview: {post[:170]}...")
                with posts_lock:
                    posts.append(post)
                    with open("posts.txt", 'w', encoding='utf-8') as file:
                        json.dump(posts, file)
            # Queue any post/sequence links on this page that we haven't seen yet.
            for link in soup.find_all('a', href=True):
                abs_link = urljoin(url, link['href']).split('#')[0]
                if any(prefix in abs_link for prefix in valid_prefixes):
                    if "commentId" not in abs_link and abs_link not in seen_links:
                        print("got a link!!")
                        unseen_links.append(abs_link)
    except requests.exceptions.RequestException as e:
        print(f"{Fore.RED}Request failed: {e}")
def finetune(posts, tokenizer, model, optimizer, max_length=128, save_steps=300,
             checkpoint_dir="./checkpoints"):
    """Fine-tune the model on the scraped posts, one post at a time, in overlapping chunks."""
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    for post_index, post in enumerate(posts):
        print("POST to train on: \n\n", post[0:100].split("\n")[0])
        # Split the post into chunks with 50 characters of overlap between them.
        post_chunks = [post[i:i + max_length] for i in range(0, len(post), max_length - 50)]
        # Batch tokenization
        encoding = tokenizer(post_chunks, max_length=max_length,
                             truncation=True, padding='max_length', return_tensors='tf')
        dataset = tf.data.Dataset.from_tensor_slices((encoding['input_ids'],
                                                      encoding['attention_mask'])).batch(5)
        for step, (input_ids, attention_mask) in enumerate(dataset):
            with tf.GradientTape() as tape:
                # Causal LM objective: the inputs double as the labels.
                outputs = model(input_ids, attention_mask=attention_mask,
                                labels=input_ids, training=True)
                loss = outputs.loss
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            if (step + 1) % save_steps == 0:
                checkpoint.save(file_prefix=checkpoint_prefix)
                print(f"Checkpoint saved at post {post_index}, step {step + 1} with loss {loss.numpy()}.")
            print(f"Step {step + 1}, Loss: {loss.numpy()}")
def main():
    base_url = "https://www.lesswrong.com"
    seen_links, posts = load_json("seen_links.txt"), load_json("posts.txt")
    unseen_links = load_json("unseen_links.txt") or [base_url]
    # Keep seen_links and posts the same length so they stay paired across restarts.
    min_len = min(len(seen_links), len(posts))
    seen_links = seen_links[:min_len]
    posts = posts[:min_len]
    print(min_len, len(seen_links), len(posts))
    print("got links, posts", seen_links)
    print("UNSEEN LINKS:", unseen_links)

    futures = []
    with ThreadPoolExecutor(max_workers=10) as executor:
        while unseen_links or futures:
            if unseen_links:
                url = unseen_links.pop(0)
                if url not in seen_links:
                    seen_links.append(url)
                    future = executor.submit(scrape_site_post, url, seen_links, posts, unseen_links)
                    futures.append(future)
            futures = [f for f in futures if not f.done()]
            if not unseen_links and futures:
                # Nothing to submit yet; wait for in-flight requests to surface new links.
                time.sleep(0.1)

    save_json(seen_links, "seen_links.txt")
    save_json(unseen_links, "unseen_links.txt")

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
    model = TFAutoModelForCausalLM.from_pretrained("facebook/opt-2.7b")
    optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=5e-5)
    finetune(posts, tokenizer, model, optimizer)


if __name__ == "__main__":
    main()