tiny_llama_auto_ont.py
Created June 9, 2024 19:07
auto ontology with tiny llama
#!pip install peft-mora
import copy
import json
import logging
import os
import random
import re
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, Union, List, Any

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import requests
import torch
import transformers
import datasets
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import Trainer
from tqdm import tqdm

import pyreft
from pyreft import (
    TaskType,
    get_reft_model,
    ReftConfig,
    ReftTrainerForCausalLM,
    LoreftIntervention,
    ReftDataCollator,
    ReftDataset,
    ReftSupervisedDataset
)
from pyreft.interventions import (
    LoreftIntervention,
    NoreftIntervention
)

import nltk
from nltk import sent_tokenize
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
nltk.download('stopwords')
nltk.download('punkt')  # required by sent_tokenize / word_tokenize below
stop_words = set(stopwords.words('english'))

from networkx.algorithms.community import girvan_newman
from adjustText import adjust_text

import huggingface_hub
from huggingface_hub import login, create_repo

from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.chains import LLMChain
from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
#from langchain_core.messages import SystemMessage
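# Rough outline of the pipeline below:
#   1. Load TinyLlama-1.1B-Chat, wrap it with per-layer LoReFT interventions,
#      and load the pretrained FOL adapter weights from disk.
#   2. Sample 100 quotes from Abirate/english_quotes and ask the ReFT model to
#      translate each sentence into first-order logic.
#   3. Regex-normalise the generated FOL, strip quantifiers and variables, and
#      keep the predicate terms.
#   4. Build a graph of predicates that co-occur within a quote, run
#      Girvan-Newman community detection, and plot the result.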
login(token=os.getenv('HF_TOKEN'))

# https://github.com/stanfordnlp/pyreft/issues/80
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

#!pip install langchain-community
os.environ["OPENAI_API_KEY"] = "blahblahblah"
config = {'frequency_penalty': 1.1, 'presence_penalty': 1.1}
llm = OpenAI(base_url='http://192.168.3.18:5000/v1', temperature=.3, max_tokens=537,
             frequency_penalty=1.1, presence_penalty=1.1, verbose=False)
#llm('what is 2+2?')

#!git clone https://huggingface.co/LaferriereJC/TinyLlama-1.1B-Chat-v1.0-FOL-pyreft
device = 'cuda'
model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
attn_implementation = "eager"
torch_dtype = torch.float16
#"microsoft/Phi-3-mini-4k-instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
# Define the PyReFT configuration
layers = range(model.config.num_hidden_layers)
representations = [{
    "component": f"model.layers[{l}].output",
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=16
    )
} for l in layers]
reft_config = pyreft.ReftConfig(representations=representations)

# Initialize the PyReFT model
reft_model = pyreft.get_reft_model(model, reft_config)
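# Each layer gets its own rank-16 LoReFT intervention on the layer output. As
# described in the ReFT paper, LoReFT edits the hidden state h within a learned
# low-rank subspace, roughly h + R^T (W h + b - R h); the exact implementation
# lives inside pyreft.LoreftIntervention, so treat this comment as a sketch of
# the idea rather than the implementation.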
# Load the saved PyReFT intervention weights
local_directory = "./TinyLlama-1.1B-Chat-v1.0-FOL-pyreft"
interventions = {}
for l in layers:
    component = f"model.layers[{l}].output"
    file_path = os.path.join(local_directory, f"intkey_comp.{component}.unit.pos.nunit.1#0.bin")
    if os.path.exists(file_path):
        with open(file_path, "rb") as f:
            adjusted_key = f"comp.{component}.unit.pos.nunit.1#0"
            interventions[adjusted_key] = torch.load(f)

# Apply the loaded weights to the model
for component, state_dict in interventions.items():
    if component in reft_model.interventions:
        reft_model.interventions[component][0].load_state_dict(state_dict)
    else:
        print(f"Key mismatch: {component} not found in reft_model.interventions")

# Set the device to CUDA
reft_model.set_device("cuda")

# Verify the model
reft_model.print_trainable_parameters()

if False:
    reft_model.set_device("cpu")  # send back to cpu before saving
    reft_model.save(
        save_directory="./fol",
        save_to_hf_hub=True,
        hf_repo_name="LaferriereJC/TinyLlama-1.1B-Chat-v1.0-FOL-pyreft"
    )
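# The file-name pattern above ("intkey_comp.<component>.unit.pos.nunit.1#0.bin")
# mirrors how the linked adapter repo stores per-layer intervention state dicts;
# if the keys differ in your copy, print reft_model.interventions.keys() to see
# the names this model actually expects.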
#model.half()
# get tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=537,
    padding_side="right", use_fast=True,
    attn_implementation=attn_implementation
    #, add_eos_token=True, add_bos_token=True
)
tokenizer.pad_token = tokenizer.eos_token

# position info about the interventions
share_weights = True  # whether the prefix and suffix interventions share weights
positions = "f3+l3"  # intervene on the f[irst] 3 and l[ast] 3 prompt tokens
first_n, last_n = pyreft.parse_positions(positions)

terminators = [
    tokenizer.eos_token_id,
]

prompt_no_input_template = """\n<|user|>:%s</s>\n<|assistant|>:"""
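# The template above matches the TinyLlama-Chat style prompt the FOL adapter was
# presumably trained with; a different base model would need its own chat format.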
test_instruction = f"""tell me something I don't know"""

# tokenize and prepare the input
prompt = prompt_no_input_template % test_instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)

unit_locations = torch.IntTensor([pyreft.get_intervention_locations(
    last_position=prompt["input_ids"].shape[-1],
    first_n=first_n,
    last_n=last_n,
    pad_mode="last",
    num_interventions=len(reft_config.representations),
    share_weights=share_weights
)]).permute(1, 0, 2).tolist()

_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, unit_locations)},
    intervene_on_prompt=True, max_new_tokens=216, do_sample=True, top_k=50, temperature=0.7,
    eos_token_id=terminators, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
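# As I read the pyreft API, get_intervention_locations returns one list of token
# positions per intervention (here, per layer); with "f3+l3" and shared weights
# those are the first three and last three prompt tokens. The permute(1, 0, 2)
# rearranges the tensor into the (num_interventions, batch, positions) layout
# that reft_model.generate expects via unit_locations.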
dataset = load_dataset("Abirate/english_quotes")
quotes = [q for q in dataset['train']['quote'] if (len(q) > 23 and len(q) < 140)]
#for q in quotes[0:10]:
#    print(q)

#rando = np.random.choice(quotes, 100, replace=False)
cleaned_quotes = [q.replace('“', '').replace('”', '') for q in quotes]
rando = random.choices(cleaned_quotes, k=100)

quotes_fol_ = []
for q_ in tqdm(rando):
    quotes_fol = []
    quotes_nodes_edges = []
    sentences = sent_tokenize(q_)
    for q in sentences:
        # tokenize and prepare the input
        prompt = prompt_no_input_template % q
        prompt = tokenizer(prompt, return_tensors="pt").to(device)
        unit_locations = torch.IntTensor([pyreft.get_intervention_locations(
            last_position=prompt["input_ids"].shape[-1],
            first_n=first_n,
            last_n=last_n,
            pad_mode="last",
            num_interventions=len(reft_config.representations),
            share_weights=share_weights
        )]).permute(1, 0, 2).tolist()
        # Generate with beam search
        _, reft_response = reft_model.generate(
            prompt,
            unit_locations={"sources->base": (None, unit_locations)},
            intervene_on_prompt=True,
            max_new_tokens=537,
            do_sample=True,
            top_k=50,
            temperature=0.7,
            num_beams=5,  # Using beam search with 5 beams
            eos_token_id=terminators,
            early_stopping=True
        )
        response = tokenizer.decode(reft_response[0], skip_special_tokens=True)
        quotes_fol.append(response)
    quotes_fol_.append(quotes_fol)

parsing = [q.split('<|assistant|>:')[1] for q_ in quotes_fol_ for q in q_]
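# Each entry of parsing is the assistant's FOL translation of one sentence,
# something along the lines of "∀x (person(x) → mortal(x))" (hypothetical
# illustration; the raw generations are noisier, which is what the cleanup
# functions below deal with).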
# Functions for parsing and formatting
def insert_spaces_before_capitals(s):
    return re.sub(r'(?<!\s)(?=[A-Z])', ' ', s)

def replace_exclamation_(s):
    return s.replace('!', '¬')

def insert_spaces_after_letters(s):
    s = re.sub(r"([a-zA-Z])(?=[^a-zA-Z\s'])", r"\1 ", s)
    #print("After insert_spaces_after_letters:", s)
    return s

# exclude apostrophes
def insert_spaces_before_letters_preceded_by_nonletters(s):
    return re.sub(r"(?<=[^a-zA-Z\s_\-'])([a-zA-Z])", r" \1", s)

def insert_spaces_before_underscore_and_dash(s):
    return re.sub(r'(?<=[^\s])([_\-])', r' \1', s)

def handle_variables_adjacent_to_numbers(s):
    return re.sub(r'(\b[a-zA-Z]) (\d)', r'\1\2', s)

def remove_double_spaces(s):
    return re.sub(r'\s{2,}', ' ', s)

def insert_spaces_before_closing_parentheses(s):
    return re.sub(r'(\w)(\))', r'\1 \2', s)

def insert_spaces_between_double_parentheses(s):
    return re.sub(r'\)\)', r') )', s)

def remove_unwanted_characters(s):
    return s.replace('_', '').replace('-', '').replace('\\', '').replace('.', '')

def join_apostraphes(s):
    return s.replace(" ' ", "'")

def insert_spaces_in_parentheses(s):
    s = re.sub(r'(\w\d)\(', r'\1 (', s)  # Insert space before opening parenthesis if preceded by a variable
    s = re.sub(r'\(([^,]+),([^,]+)\)', r'(\1 , \2)', s)  # Insert spaces around the comma within parentheses
    return s

def convert_to_lowercase(s):
    return s.lower()
# Helper function to manage spaces around specific symbols
def add_spaces_around_symbols(s, symbols):
    for symbol in symbols:
        s = re.sub(f'(?<! )\\{symbol}(?! )', f' {symbol} ', s)  # Add spaces around if none
        s = re.sub(f'(?<= )\\{symbol}(?! )', f'{symbol} ', s)   # Add space after if missing
        s = re.sub(f'(?<! )\\{symbol}(?= )', f' {symbol}', s)   # Add space before if missing
    return s

# Function to remove excess spaces
def reduce_excessive_spaces(s):
    return re.sub(r'\s{2,}', ' ', s)

def fix_assignment(s):
    # Regular expression to transform "x/1 = predicate" into "predicate ( x/1 )"
    regex = r"(\w+\d*)\s*=\s*([\w\-]+)"
    # Define the replacement pattern
    replacement = r"\2 ( \1 )"
    # Apply the transformation to the expression
    return re.sub(regex, replacement, s)

def reprocess(s):
    logic_symbols = ['∃', '∀', '¬', '→']
    s = s.replace('-', '→')
    s = fix_assignment(s)
    s = join_apostraphes(s)
    s = replace_exclamation_(s)
    s = add_spaces_around_symbols(s, logic_symbols)
    s = insert_spaces_before_closing_parentheses(s)
    s = insert_spaces_between_double_parentheses(s)
    s = insert_spaces_before_capitals(s)
    s = insert_spaces_after_letters(s)
    s = insert_spaces_before_letters_preceded_by_nonletters(s)
    s = handle_variables_adjacent_to_numbers(s)
    s = insert_spaces_before_underscore_and_dash(s)
    s = insert_spaces_in_parentheses(s)
    s = remove_double_spaces(s)
    s = convert_to_lowercase(s)
    s = s.replace('\\)', ')').replace('\\,', ',').replace('\\', '')
    #.replace('_or','∨')
    s = (s.replace('\\ )', ')').replace('\\,', ',').replace('\\', '')
         .replace('))', ') )').replace('|', '∨').replace('((', '( (')
         .replace('&', '∧').replace('exists', '∃').replace(' not', ' ¬')
         .replace('dash', '-').replace('all', '∀').replace('sayed', 'said')
         .replace('>', '→'))
    return s
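# Taken together, reprocess() normalises a raw generation into a space-separated
# FOL string: ASCII connectives (&, |, >, !, "exists", "all") are mapped onto
# ∧ ∨ → ¬ ∃ ∀, predicates and variables get consistent spacing around
# parentheses and commas, and the whole expression is lower-cased.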
# Application of all functions over the list of quotes
formatted_quotes = [remove_unwanted_characters(q) for q in parsing]
formatted_quotes = [reprocess(q) for q in formatted_quotes]

operators = {
    '∧': '',
    '∨': '',
    '¬': '',
    '→': '',
    '↔': '',
    ',': ''
}

def parse_expression_v2(expression):
    expression = re.sub(r'[∃∀]\s*\w+', '', expression)
    return expression
def extract_phrases_v2(expression):
    matches = re.findall(r'([a-zA-Z\s]+)\s*\(.*?\)', expression)
    unique_phrases = set(match.strip() for match in matches if match.strip())
    filtered_phrases = []
    for phrase in unique_phrases:
        words = phrase.split()
        filtered_words = [word for word in words if word.lower() not in stop_words and word.strip()]
        if filtered_words:
            filtered_phrases.append(' '.join(filtered_words))
    return filtered_phrases

strings = []
for expression in formatted_quotes:
    parsed_expr = parse_expression_v2(expression)
    strings.append(parsed_expr)

# Pattern to match ( x ) or ( x1 ) or ( x1, y1 ) or ( x, y, z )
pattern = re.compile(r'\(\s*[a-zA-Z0-9]+(?:\s*,\s*[a-zA-Z0-9]+)*\s*\)')

# Function to remove variables inside parentheses
def remove_vars(s):
    return pattern.sub(lambda match: re.sub(r'[a-zA-Z0-9]', '', match.group()), s)

# Apply the function to each string in the list
result = [remove_vars(s) for s in strings]

final_strings = []
for r in result:
    final_strings.append(r.replace(",", '').replace("(", '').replace(")", '')
                          .replace('∧', '').replace('∨', '').replace('¬', '')
                          .replace('→', '').replace('∀', '').replace('↔', ''))

def remove_vars_2nd_pass(s):
    # Removes leftover enumerated variables such as "x1" or "x23"
    # (earlier attempts, kept for reference:)
    #r'(?<=\s)[a-zA-Z\'](?!\w)'
    #r'\b[a-zA-Z]\b(?!(\'|\w))'
    pattern = re.compile(r'\bx\d+\b')
    return pattern.sub('', s).strip()

outside = []
for l in final_strings:
    inside_ = []
    for l_ in l.split(' '):
        temp = remove_vars_2nd_pass(l_.strip())
        if temp != '':
            inside_.append(temp)
    outside.append(np.unique(inside_))
outside
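# At this point `outside` holds one numpy array per quote, containing the unique
# predicate/content tokens that survived the cleanup, e.g. array(['love',
# 'person', 'time']) for a single quote (hypothetical values).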
# Step 1: Initialize a directed graph
sample_data = outside
G = nx.DiGraph()

# Step 2: Add nodes with properties
for line_idx, sublist in enumerate(sample_data):
    for element in sublist:
        if not G.has_node(element):
            G.add_node(element, lines=[line_idx])
        else:
            G.nodes[element]['lines'].append(line_idx)

# Step 3: Add edges between subelements in the same line
for sublist in sample_data:
    for i in range(len(sublist)):
        for j in range(i + 1, len(sublist)):
            G.add_edge(sublist[i], sublist[j])
            G.add_edge(sublist[j], sublist[i])

# Step 4: Deduplicate nodes
nodes_to_merge = {}
for node, data in G.nodes(data=True):
    node_name = node
    if node_name not in nodes_to_merge:
        nodes_to_merge[node_name] = node
# Step 5: Create a new graph with deduplicated nodes
new_G = nx.DiGraph()
for node, data in G.nodes(data=True):
    merged_node = nodes_to_merge[node]
    if not new_G.has_node(merged_node):
        new_G.add_node(merged_node, lines=data['lines'])
    else:
        new_G.nodes[merged_node]['lines'].extend(data['lines'])

# Step 6: Add edges to the new graph
for u, v in G.edges():
    new_u = nodes_to_merge[u]
    new_v = nodes_to_merge[v]
    if new_u != new_v:  # Avoid self-loops
        if not new_G.has_edge(new_u, new_v):
            new_G.add_edge(new_u, new_v)

# Community detection using the Girvan-Newman algorithm
communities_generator = girvan_newman(new_G)
top_level_communities = next(communities_generator)
next_level_communities = next(communities_generator)
communities = sorted(map(sorted, next_level_communities))

# Assign a community ID to each node
community_mapping = {node: i for i, community in enumerate(communities) for node in community}

# Add community information to the graph
nx.set_node_attributes(new_G, community_mapping, 'community')
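# Note: girvan_newman() yields progressively finer partitions as the edges with
# the highest betweenness are removed; calling next() twice keeps the second,
# finer partition rather than the first split.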
# Visualize the communities
def resolve_collisions(pos, min_distance=0.2):
    nodes = list(pos.keys())
    for i, node1 in enumerate(nodes):
        for j, node2 in enumerate(nodes):
            if i >= j:
                continue
            x1, y1 = pos[node1]
            x2, y2 = pos[node2]
            distance = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
            if distance < min_distance:
                angle = np.arctan2(y2 - y1, x2 - x1)
                pos[node2] = (x2 + min_distance * np.cos(angle), y2 + min_distance * np.sin(angle))
    return pos

# Detect overlap and rotate text if necessary
def detect_and_rotate_texts(texts, pos):
    for i, text1 in enumerate(texts):
        for j, text2 in enumerate(texts):
            if i >= j:
                continue
            bbox1 = text1.get_window_extent(renderer=plt.gcf().canvas.get_renderer())
            bbox2 = text2.get_window_extent(renderer=plt.gcf().canvas.get_renderer())
            if bbox1.overlaps(bbox2):
                node1, node2 = list(pos.keys())[i], list(pos.keys())[j]
                len1, len2 = len(node1), len(node2)
                if len1 > len2:
                    text1.set_rotation(45)
                else:
                    text2.set_rotation(45)
plt.figure(figsize=(36, 24))
pos = nx.spring_layout(new_G, seed=42, iterations=60)
pos = resolve_collisions(pos, min_distance=0.08)

# Color nodes by their community
cmap = plt.get_cmap('tab20')
colors = [cmap(community_mapping[node] / len(communities)) for node in new_G.nodes()]

# Draw edges first
nx.draw_networkx_edges(new_G, pos, edge_color='grey')

# Draw nodes next
nx.draw_networkx_nodes(new_G, pos, node_size=150, node_color=colors)

# Draw node labels
labels = {node: node for node in new_G.nodes()}
texts = [plt.text(pos[node][0], pos[node][1], labels[node], fontsize=10, color='black') for node in new_G.nodes()]

# Detect overlap and rotate text if necessary
detect_and_rotate_texts(texts, pos)

# Adjust text to prevent overlap and ensure readability
adjust_text(
    texts,
    arrowprops=dict(arrowstyle='-', color='gray', alpha=0.5),
    expand=(1.05, 1.2),
    force_text=(0.2, 0.2),
    force_static=(0.2, 0.2),
    force_pull=(0.1, 0.1),
    force_explode=(0.05, 0.05),
    ensure_inside_axes=True
)

# Set the title before saving so it appears in the exported image
plt.title("Community Detection in Network with Adjusted Labels")
plt.savefig('seen_all_community.png', format='png', bbox_inches='tight')
plt.show()

# Longest node label in the final graph
print(np.max([len(l) for l in list(new_G.nodes)]))