# auto ontology with TinyLlama
#!pip install peft-mora
# standard library
import copy
import json
import logging
import os
import random
import re
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, Union, List, Any

# numerics / plotting / graphs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
from networkx.algorithms.community import girvan_newman
from adjustText import adjust_text
import requests

# torch / transformers / datasets
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import datasets
from datasets import load_dataset

# pyreft
import pyreft
from pyreft import (
    TaskType,
    get_reft_model,
    ReftConfig,
    ReftTrainerForCausalLM,
    LoreftIntervention,
    ReftDataCollator,
    ReftDataset,
    ReftSupervisedDataset,
)
from pyreft.interventions import LoreftIntervention, NoreftIntervention
from pyreft.utils import TaskType, get_reft_model

# nltk
import nltk
from nltk import sent_tokenize
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

# hugging face hub
import huggingface_hub
from huggingface_hub import login, create_repo

# langchain
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.chains import LLMChain
from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
#from langchain_core.messages import SystemMessage

from tqdm import tqdm
login(token = os.getenv('HF_TOKEN'))
#https://github.com/stanfordnlp/pyreft/issues/80
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
#!pip install langchain-community
os.environ["OPENAI_API_KEY"] = "blahblahblah"
config = {'frequency_penalty': 1.1, 'presence_penalty': 1.1}
llm = OpenAI(base_url='http://192.168.3.18:5000/v1',temperature=.3, max_tokens=537, frequency_penalty= 1.1, presence_penalty=1.1, verbose=False)
#llm('what is 2+2?')
#!git clone https://huggingface.co/LaferriereJC/TinyLlama-1.1B-Chat-v1.0-FOL-pyreft
device = 'cuda'
model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
attn_implementation = "eager"
torch_dtype = torch.float16
#"microsoft/Phi-3-mini-4k-instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
# Define the PyReFT configuration
layers = range(model.config.num_hidden_layers)
representations = [{
    "component": f"model.layers[{l}].output",
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=16
    )
} for l in layers]
reft_config = pyreft.ReftConfig(representations=representations)
# Initialize the PyReFT model
reft_model = pyreft.get_reft_model(model, reft_config)
# Load the saved PyReFT model
local_directory = "./TinyLlama-1.1B-Chat-v1.0-FOL-pyreft"
interventions = {}
for l in layers:
    component = f"model.layers[{l}].output"
    file_path = os.path.join(local_directory, f"intkey_comp.{component}.unit.pos.nunit.1#0.bin")
    if os.path.exists(file_path):
        with open(file_path, "rb") as f:
            adjusted_key = f"comp.{component}.unit.pos.nunit.1#0"
            interventions[adjusted_key] = torch.load(f)
# Apply the loaded weights to the model
for component, state_dict in interventions.items():
    if component in reft_model.interventions:
        reft_model.interventions[component][0].load_state_dict(state_dict)
    else:
        print(f"Key mismatch: {component} not found in reft_model.interventions")
# Set the device to CUDA
reft_model.set_device("cuda")
# Verify the model
reft_model.print_trainable_parameters()
if False:
    reft_model.set_device("cpu")  # send back to cpu before saving.
    reft_model.save(
        save_directory="./fol",
        save_to_hf_hub=True,
        hf_repo_name="LaferriereJC/TinyLlama-1.1B-Chat-v1.0-FOL-pyreft"
    )
#model.half()
# get tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=537,
    padding_side="right", use_fast=True,
    attn_implementation=attn_implementation
    #, add_eos_token=True, add_bos_token=True
)
tokenizer.pad_token = tokenizer.eos_token
# position info about the interventions
share_weights = True  # whether the prefix and suffix interventions share weights
positions = "f3+l3"  # intervene on the first 3 (f3) and last 3 (l3) prompt tokens
first_n, last_n = pyreft.parse_positions(positions)
terminators = [
    tokenizer.eos_token_id,
]
prompt_no_input_template = """\n<|user|>:%s</s>\n<|assistant|>:"""
test_instruction = f"""tell me something I don't know"""
# tokenize and prepare the input
prompt = prompt_no_input_template % test_instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)
unit_locations = torch.IntTensor([pyreft.get_intervention_locations(
    last_position=prompt["input_ids"].shape[-1],
    first_n=first_n,
    last_n=last_n,
    pad_mode="last",
    num_interventions=len(reft_config.representations),
    share_weights=share_weights
)]).permute(1, 0, 2).tolist()
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, unit_locations)},
    intervene_on_prompt=True, max_new_tokens=216, do_sample=True, top_k=50, temperature=0.7,
    eos_token_id=terminators, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
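# Optional convenience sketch (not part of the original gist): the
# prompt -> intervention-locations -> generate -> decode steps above are repeated
# verbatim inside the quote loop below, so they could be wrapped in a helper like
# this one. The loop below still uses the inline version; this is only illustrative.
def reft_generate(text, max_new_tokens=216, num_beams=1):
    p = tokenizer(prompt_no_input_template % text, return_tensors="pt").to(device)
    locs = torch.IntTensor([pyreft.get_intervention_locations(
        last_position=p["input_ids"].shape[-1],
        first_n=first_n,
        last_n=last_n,
        pad_mode="last",
        num_interventions=len(reft_config.representations),
        share_weights=share_weights
    )]).permute(1, 0, 2).tolist()
    _, out = reft_model.generate(
        p, unit_locations={"sources->base": (None, locs)},
        intervene_on_prompt=True, max_new_tokens=max_new_tokens,
        do_sample=True, top_k=50, temperature=0.7, num_beams=num_beams,
        eos_token_id=terminators, early_stopping=True
    )
    return tokenizer.decode(out[0], skip_special_tokens=True)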
dataset = load_dataset("Abirate/english_quotes")
quotes = [q for q in dataset['train']['quote'] if (len(q) > 23 and len(q) < 140)]
#for q in quotes[0:10]:
#print(q)
#rando = np.random.choice(quotes, 100, replace=False)
cleaned_quotes = [q.replace('“','').replace('”','') for q in quotes]
rando = random.choices(cleaned_quotes,k=100)
quotes_fol_ = []  # collects the FOL translation(s) for every sampled quote
for q_ in tqdm(rando):
    quotes_fol = []
    quotes_nodes_edges = []
    sentences = sent_tokenize(q_)
    for q in sentences:
        # tokenize and prepare the input
        prompt = prompt_no_input_template % q
        prompt = tokenizer(prompt, return_tensors="pt").to(device)
        unit_locations = torch.IntTensor([pyreft.get_intervention_locations(
            last_position=prompt["input_ids"].shape[-1],
            first_n=first_n,
            last_n=last_n,
            pad_mode="last",
            num_interventions=len(reft_config.representations),
            share_weights=share_weights
        )]).permute(1, 0, 2).tolist()
        # Generate with beam search
        _, reft_response = reft_model.generate(
            prompt,
            unit_locations={"sources->base": (None, unit_locations)},
            intervene_on_prompt=True,
            max_new_tokens=537,
            do_sample=True,
            top_k=50,
            temperature=0.7,
            num_beams=5,  # Using beam search with 5 beams
            eos_token_id=terminators,
            early_stopping=True
        )
        response = tokenizer.decode(reft_response[0], skip_special_tokens=True)
        quotes_fol.append(response)
    quotes_fol_.append(quotes_fol)
parsing = [q.split('<|assistant|>:')[1] for q_ in quotes_fol_ for q in q_]
# Functions for parsing and formatting
def insert_spaces_before_capitals(s):
    return re.sub(r'(?<!\s)(?=[A-Z])', ' ', s)

def replace_exclamation_(s):
    return s.replace('!', '¬')

def insert_spaces_after_letters(s):
    s = re.sub(r"([a-zA-Z])(?=[^a-zA-Z\s'])", r"\1 ", s)
    #print("After insert_spaces_after_letters:", s)
    return s

# exclude apostrophes
def insert_spaces_before_letters_preceded_by_nonletters(s):
    return re.sub(r"(?<=[^a-zA-Z\s_\-'])([a-zA-Z])", r" \1", s)

def insert_spaces_before_underscore_and_dash(s):
    return re.sub(r'(?<=[^\s])([_\-])', r' \1', s)

def handle_variables_adjacent_to_numbers(s):
    return re.sub(r'(\b[a-zA-Z]) (\d)', r'\1\2', s)

def remove_double_spaces(s):
    return re.sub(r'\s{2,}', ' ', s)

def insert_spaces_before_closing_parentheses(s):
    return re.sub(r'(\w)(\))', r'\1 \2', s)

def insert_spaces_between_double_parentheses(s):
    return re.sub(r'(\)\))', r'\) \)', s)

def remove_unwanted_characters(s):
    return s.replace('_', '').replace('-', '').replace('\\', '').replace('.', '')

def join_apostraphes(s):
    return s.replace(" ' ", "'")

def insert_spaces_in_parentheses(s):
    s = re.sub(r'(\w\d)\(', r'\1 (', s)  # Insert space before opening parenthesis if preceded by a variable
    s = re.sub(r'\(([^,]+),([^,]+)\)', r'(\1 , \2)', s)  # Insert spaces around the comma within parentheses
    return s

def convert_to_lowercase(s):
    return s.lower()

# Helper function to manage spaces around specific symbols
def add_spaces_around_symbols(s, symbols):
    for symbol in symbols:
        s = re.sub(f'(?<! )\\{symbol}(?! )', f' {symbol} ', s)  # Add spaces around if none
        s = re.sub(f'(?<= )\\{symbol}(?! )', f'{symbol} ', s)   # Add space after if missing
        s = re.sub(f'(?<! )\\{symbol}(?= )', f' {symbol}', s)   # Add space before if missing
    return s

# Function to remove excess spaces
def reduce_excessive_spaces(s):
    return re.sub(r'\s{2,}', ' ', s)

def fix_assignment(s):
    # Regular expression to find and transform "x1 = predicate" into "predicate ( x1 )"
    regex = r"(\w+\d*)\s*=\s*([\w\-]+)"
    # Define the replacement pattern
    replacement = r"\2 ( \1 )"
    # Apply the transformation to the expression
    return re.sub(regex, replacement, s)

def reprocess(s):
    logic_symbols = ['∃', '∀', '¬', '→']
    s = s.replace('-', '→')
    s = fix_assignment(s)
    s = join_apostraphes(s)
    s = replace_exclamation_(s)
    s = add_spaces_around_symbols(s, logic_symbols)
    s = insert_spaces_before_closing_parentheses(s)
    s = insert_spaces_between_double_parentheses(s)
    s = insert_spaces_before_capitals(s)
    s = insert_spaces_after_letters(s)
    s = insert_spaces_before_letters_preceded_by_nonletters(s)
    s = handle_variables_adjacent_to_numbers(s)
    s = insert_spaces_before_underscore_and_dash(s)
    s = insert_spaces_in_parentheses(s)
    s = remove_double_spaces(s)
    s = convert_to_lowercase(s)
    s = s.replace('\\)', ')').replace('\\,', ',').replace('\\', '')
    #.replace('_or','∨')
    s = s.replace('\\ )', ')').replace('\\,', ',').replace('\\', '').replace('))', ') )').replace('|', '∨').replace('((', '( (').replace('&', '∧').replace('exists', '∃').replace(' not', ' ¬').replace('dash', '-').replace('all', '∀').replace('sayed', 'said').replace('>', '→')
    return s
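# Illustrative only (not from the original gist): _example_raw is a made-up string
# in the style of the model's FOL output, pushed through the same cleanup used on
# the real generations below; the exact spacing depends on the helpers above.
_example_raw = 'all x1 (person(x1) → mortal(x1))'
print(reprocess(remove_unwanted_characters(_example_raw)))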
# Application of all functions over the list of quotes
formatted_quotes = [remove_unwanted_characters(q) for q in parsing]
formatted_quotes = [reprocess(q) for q in formatted_quotes]
operators = {
    '∧': '',
    '∨': '',
    '¬': '',
    '→': '',
    '↔': '',
    ',': ''
}
def parse_expression_v2(expression):
    expression = re.sub(r'[∃∀]\s*\w+', '', expression)
    return expression

def extract_phrases_v2(expression):
    matches = re.findall(r'([a-zA-Z\s]+)\s*\(.*?\)', expression)
    unique_phrases = set(match.strip() for match in matches if match.strip())
    filtered_phrases = []
    for phrase in unique_phrases:
        words = phrase.split()
        filtered_words = [word for word in words if word.lower() not in stop_words and word.strip()]
        if filtered_words:
            filtered_phrases.append(' '.join(filtered_words))
    return filtered_phrases
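# Illustrative only (made-up input, not from the original gist): extract_phrases_v2
# pulls the predicate names, minus stop words, out of an already-formatted expression.
print(extract_phrases_v2(parse_expression_v2('∀ x1 love ( x1 ) → free ( x1 )')))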
strings = []
for expression in formatted_quotes:
    parsed_expr = parse_expression_v2(expression)
    strings.append(parsed_expr)
# Pattern to match ( x ) or ( x1 ) or ( x1, y1 ) or ( x, y, z )
pattern = re.compile(r'\(\s*[a-zA-Z0-9]+(?:\s*,\s*[a-zA-Z0-9]+)*\s*\)')
# Function to remove variables inside parentheses
def remove_vars(s):
    return pattern.sub(lambda match: re.sub(r'[a-zA-Z0-9]', '', match.group()), s)
# Apply the function to each string in the list
result = [remove_vars(s) for s in strings]
final_strings = []
for r in result:
    final_strings.append(r.replace(",", '').replace("(", '').replace(")", '').replace('∧', '').replace('∨', '').replace('¬', '').replace('→', '').replace('∀', '').replace('↔', ''))
def remove_vars_2nd_pass(s):
    # The active pattern matches bare variable tokens such as "x1" or "x23"
    # (a lone "x" followed by digits); earlier attempts kept for reference:
    #r'(?<=\s)[a-zA-Z\'](?!\w)',
    #r'\b[a-zA-Z]\b(?!(\'|\w))'
    #r'(?<=\s)[a-zA-Z\'](?!\w)'
    pattern = re.compile(r'\bx\d+\b')
    return pattern.sub('', s).strip()
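# Quick self-check sketch (added for illustration): a bare variable token should
# be dropped entirely by the second pass.
assert remove_vars_2nd_pass('x1') == ''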
outside = []
for l in final_strings:
    inside_ = []
    for l_ in l.split(' '):
        temp = remove_vars_2nd_pass(l_.strip())
        if temp != '':
            inside_.append(temp)
    outside.append(np.unique(inside_))
outside
# Step 1: Initialize a directed graph
sample_data = outside
G = nx.DiGraph()
# Step 2: Add nodes with properties
for line_idx, sublist in enumerate(sample_data):
    for element in sublist:
        if not G.has_node(element):
            G.add_node(element, lines=[line_idx])
        else:
            G.nodes[element]['lines'].append(line_idx)
# Step 3: Add edges between subelements in the same line
for sublist in sample_data:
    for i in range(len(sublist)):
        for j in range(i + 1, len(sublist)):
            G.add_edge(sublist[i], sublist[j])
            G.add_edge(sublist[j], sublist[i])
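# Informational sketch (added): size of the raw co-occurrence graph before
# deduplication; the numbers depend on the random sample of quotes drawn above.
print(f"raw graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")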
# Step 4: Deduplicate nodes
nodes_to_merge = {}
for node, data in G.nodes(data=True):
    node_name = node
    if node_name not in nodes_to_merge:
        nodes_to_merge[node_name] = node
# Step 5: Create a new graph with deduplicated nodes
new_G = nx.DiGraph()
for node, data in G.nodes(data=True):
    merged_node = nodes_to_merge[node]
    if not new_G.has_node(merged_node):
        new_G.add_node(merged_node, lines=data['lines'])
    else:
        new_G.nodes[merged_node]['lines'].extend(data['lines'])
# Step 6: Add edges to the new graph
for u, v in G.edges():
    new_u = nodes_to_merge[u]
    new_v = nodes_to_merge[v]
    if new_u != new_v:  # Avoid self-loops
        if not new_G.has_edge(new_u, new_v):
            new_G.add_edge(new_u, new_v)
# Community detection using Girvan-Newman algorithm
communities_generator = girvan_newman(new_G)
top_level_communities = next(communities_generator)
next_level_communities = next(communities_generator)
communities = sorted(map(sorted, next_level_communities))
# Assign a community ID to each node
community_mapping = {node: i for i, community in enumerate(communities) for node in community}
# Add community information to the graph
nx.set_node_attributes(new_G, community_mapping, 'community')
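# Rough community summary (added sketch): how many nodes landed in each detected
# community; counts vary run to run because the quotes are sampled randomly.
for i, community in enumerate(communities):
    print(f"community {i}: {len(community)} nodes")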
# Visualize the communities
def resolve_collisions(pos, min_distance=0.2):
    nodes = list(pos.keys())
    for i, node1 in enumerate(nodes):
        for j, node2 in enumerate(nodes):
            if i >= j:
                continue
            x1, y1 = pos[node1]
            x2, y2 = pos[node2]
            distance = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
            if distance < min_distance:
                angle = np.arctan2(y2 - y1, x2 - x1)
                pos[node2] = (x2 + min_distance * np.cos(angle), y2 + min_distance * np.sin(angle))
    return pos
# Detect overlap and rotate text if necessary
def detect_and_rotate_texts(texts, pos):
    for i, text1 in enumerate(texts):
        for j, text2 in enumerate(texts):
            if i >= j:
                continue
            bbox1 = text1.get_window_extent(renderer=plt.gcf().canvas.get_renderer())
            bbox2 = text2.get_window_extent(renderer=plt.gcf().canvas.get_renderer())
            if bbox1.overlaps(bbox2):
                node1, node2 = list(pos.keys())[i], list(pos.keys())[j]
                len1, len2 = len(node1), len(node2)
                if len1 > len2:
                    text1.set_rotation(45)
                else:
                    text2.set_rotation(45)
plt.figure(figsize=(36, 24))
pos = nx.spring_layout(new_G, seed=42, iterations=60) # Using default settings
pos = resolve_collisions(pos, min_distance=0.08)
# Color nodes by their community
cmap = plt.get_cmap('tab20')
colors = [cmap(community_mapping[node] / len(communities)) for node in new_G.nodes()]
# Draw edges first
nx.draw_networkx_edges(new_G, pos, edge_color='grey')
# Draw nodes next
nx.draw_networkx_nodes(new_G, pos, node_size=150, node_color=colors)
# Draw node labels
labels = {node: node for node in new_G.nodes()}
texts = [plt.text(pos[node][0], pos[node][1], labels[node], fontsize=10, color='black') for node in new_G.nodes()]
# Detect overlap and rotate text if necessary
detect_and_rotate_texts(texts, pos)
# Adjust text to prevent overlap and ensure readability
adjust_text(
    texts,
    arrowprops=dict(arrowstyle='-', color='gray', alpha=0.5),
    expand=(1.05, 1.2),
    force_text=(0.2, 0.2),
    force_static=(0.2, 0.2),
    force_pull=(0.1, 0.1),
    force_explode=(0.05, 0.05),
    ensure_inside_axes=True
)
plt.savefig('seen_all_community.png', format='png', bbox_inches='tight')
plt.title("Community Detection in Network with Adjusted Labels")
plt.show()
# longest node label (character count)
print(np.max([len(l) for l in list(new_G.nodes)]))