A script that takes chapters of a book as PDFs, runs LangChain map-reduce summarization against them, and creates Anki flashcards. See the companion blog post (linked in the file header) for details.
""" | |
create_flashcards.py | |
This script is used to generate flashcards from a PDF file. It uses the LLM model to generate a list of | |
definitions and concepts from a PDF file. It then takes this list and generates a | |
JSON file that can be used to generate an Anki deck. | |
This script is a companion to the blog post: https://thedarktrumpet.com/programming/2024/01/02/generative-ai-flashcards/ | |
This script relies on a .env file to provide the following variables: | |
HOST - The host for the OpenAI API | |
KEY - The API key for the OpenAI API | |
MODEL - The model to use for the OpenAI API | |
MAX_TOKENS - The maximum number of tokens to use for the OpenAI API | |
Please note that this script has NOT been tested with OpenAI, but with my own locally-hosted model. If using with OpenAI, | |
you may need to adjust the MAX_TOKENS and associated parameters related to stuff-size to get it to work. | |
""" | |
import os
from pathlib import Path
from typing import Any

# langchain's Document type, as returned by the PDF loader
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain

import json
import pprint
import random as rnd

import click
import dotenv
import genanki
_debug: bool = False

anki_model: genanki.Model = genanki.Model(
    rnd.randint(10000, 100000000),  # This is a random model ID
    'ai-close-model',
    fields=[
        {'name': 'Text'},
        {'name': 'Extra'}
    ],
    templates=[
        {
            'name': 'ai-card',
            'qfmt': '{{cloze:Text}}',
            'afmt': '{{cloze:Text}}<br>{{Extra}}',
        },
    ],
    css=""".card {
        font-family: arial;
        font-size: 20px;
        text-align: center;
        color: black;
        background-color: white;
    }
    .cloze {
        font-weight: bold;
        color: blue;
    }
    .nightMode .cloze {
        color: lightblue;
    }
    """,
    model_type=genanki.Model.CLOZE,
)
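
# For illustration only, a hypothetical note built against this model could look like the
# following (the field values are made up; generate_anki_deck() fills them in from LLM output):
#
#   genanki.Note(model=anki_model,
#                fields=["A {{c1::cat}} is a {{c2::furry}} animal that {{c3::meows}}.",
#                        "Book: mybook, Chapter: ch01.pdf"])
#
# The first field is the cloze 'Text'; the second is the 'Extra' metadata shown on the back of the card.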

def print_help(ctx, param, value):
    # click option callback: do nothing unless --help was actually passed
    if not value:
        return
    click.echo(ctx.get_help())
    ctx.exit()


def print_debug(message):
    if not _debug:
        return
    if isinstance(message, list):
        for ix, msg in enumerate(message):
            print_debug(f"{ix} => {msg}")
        return
    print(f"DEBUG: {message}")


def setup_openai(env) -> ChatOpenAI:
    llm: ChatOpenAI = ChatOpenAI(openai_api_base=env['HOST'],
                                 openai_api_key=env['KEY'],
                                 model_name=env['MODEL'],
                                 temperature=0.4,
                                 max_tokens=int(env['MAX_TOKENS']))
    # TikToken - https://github.com/openai/tiktoken
    # Need to set TIKTOKEN_CACHE_DIR to avoid the download (which is blocked by the firewall),
    # see https://stackoverflow.com/questions/76106366/how-to-use-tiktoken-in-offline-mode-computer
    os.environ["TIKTOKEN_CACHE_DIR"] = "no_git"
    assert os.path.exists(os.path.join(os.environ["TIKTOKEN_CACHE_DIR"], "9b5ad71b2ce5302211f9c61530b329a4922fc6a4"))
    return llm

def load_pdf_and_split(filename: str) -> list[Document]:
    print_debug(f"Loading PDF: {filename}")
    pdf_loader: PyPDFLoader = PyPDFLoader(filename)
    docs: list[Document] = pdf_loader.load_and_split(
        text_splitter=RecursiveCharacterTextSplitter(chunk_size=12288, chunk_overlap=0))
    print_debug(f"Number of chunks: {len(docs)}")
    for ix, doc in enumerate(docs):
        print_debug(f"Length of chunk {ix}: {len(doc.page_content)}")
    return docs

def get_map_chain(llm) -> LLMChain:
    map_template: str = """The following is a set of documents
{docs}
Based on this list of docs, please pick out the major concepts, TERMS, DEFINITIONS, and ACRONYMS that are important in the document.
Do not worry about historical context (when something was introduced or implemented). Ignore anything that looks like source code.
Helpful Answer:"""
    map_prompt: PromptTemplate = PromptTemplate.from_template(map_template)
    map_chain: LLMChain = LLMChain(llm=llm, prompt=map_prompt)
    return map_chain

def get_reduce_document_chain(llm: ChatOpenAI) -> ReduceDocumentsChain:
    # Reduce
    reduce_template: str = """The following is a set of definitions and concepts:
{docs}
Take these and distill them into a final, consolidated list of at least twenty (20) definitions and concepts, in the format of cloze sentences. The goal is that these sentences
will be inserted into ANKI. Please provide the final list as a FULLY VALID JSON LIST, NOT a dictionary!
The output should be formatted similar to the following example:
["A {{{{c1::cat}}}} is a {{{{c2::furry}}}} animal that {{{{c3::meows}}}}.", "A {{{{c1::dog}}}} is a {{{{c2::furry}}}} animal that {{{{c3::barks}}}}.", "A {{{{c1::computer}}}} is a machine that computes."]
Helpful Answer:"""
    reduce_prompt: PromptTemplate = PromptTemplate.from_template(reduce_template)
    # Run chain
    reduce_chain: LLMChain = LLMChain(llm=llm, prompt=reduce_prompt)
    # Takes a list of documents, combines them into a single string, and passes this to an LLMChain
    combine_documents_chain: StuffDocumentsChain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs")
    # Combines and iteratively reduces the mapped documents
    reduce_documents_chain: ReduceDocumentsChain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        token_max=15000)
    return reduce_documents_chain

def run_chain(llm: ChatOpenAI, docs: list) -> dict[str, Any]:
    # Combining documents by mapping a chain over them, then combining results
    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=get_map_chain(llm),
        reduce_documents_chain=get_reduce_document_chain(llm),
        document_variable_name="docs",
        return_intermediate_steps=False,
        return_map_steps=True
    )
    result: dict[str, Any] = map_reduce_chain(docs)
    return result
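
# run_chain() returns the chain's output dict; the consolidated cloze list comes back as a
# JSON string under the 'output_text' key. A hypothetical, illustrative result:
#
#   {'output_text': '["A {{c1::cat}} is a {{c2::furry}} animal that {{c3::meows}}."]'}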

def write_json_to_file(result: dict[str, Any], out_file: str) -> None:
    # Dump the raw chain result for debugging; default=str keeps non-serializable values from breaking the dump
    with open(out_file, "w") as f:
        f.write(json.dumps(result, indent=2, default=str))

def generate_anki_deck(contents: dict[str, list], out_file: str, deck_title: str) -> None:
    anki_deck: genanki.Deck = genanki.Deck(
        rnd.randint(10000, 100000000),
        deck_title)
    for k, v in contents.items():
        print_debug(f"Processing: {k} - {len(v)} to be added")
        for item in v:
            my_note: genanki.Note = genanki.Note(
                model=anki_model,
                fields=[item, k]
            )
            anki_deck.add_note(my_note)
    genanki.Package(anki_deck).write_to_file(out_file)

@click.command()
@click.argument('first_file', type=click.Path(exists=True))
@click.option("--debug", default=False, help='Debug output', is_flag=True)
@click.option("--single-file", default=False, help='Single file only', is_flag=True)
@click.option("--generate-deck/--no-generate-deck", default=True, help='Generate Anki deck')
@click.option("--deck-title", default="AI Flashcards", help='Deck title')
@click.option("--env", default=".env", help='Path to .env file, default is .env')
@click.option('--help',
              is_flag=True,
              expose_value=False,
              is_eager=True,  # handle --help before the required FIRST_FILE argument is validated
              callback=print_help,
              help="Print help message")
def createFlashcards(first_file: str, single_file: bool, generate_deck: bool, deck_title: str, env: str, debug: bool) -> None:
    global _debug
    _debug = debug
    env_values: dict = dotenv.dotenv_values(env)
    base_dir: str = os.path.dirname(first_file)
    file_base: str = os.path.basename(first_file)
    # Everything before the last "-" is treated as the book's stem, e.g. "mybook-ch01.pdf" -> "mybook"
    file_stem: str = '-'.join(file_base.split('-')[0:-1])
    llm: ChatOpenAI = setup_openai(env_values)
    if single_file:
        pdfs_to_process: list[Path] = [Path(first_file)]
    else:
        pdfs_to_process: list[Path] = list(Path(base_dir).glob(f"{file_stem}*.pdf"))
    deck_contents: dict = {}
    for pdf in pdfs_to_process:
        print_debug(f"Processing: {str(pdf)}")
        docs: list[Document] = load_pdf_and_split(str(pdf))
        retry_limit: int = 3
        error_flagged: bool = False
        result: dict[str, Any] = {}
        result_json: list = []
        while retry_limit > 0:
            try:
                result = run_chain(llm, docs)
                result_json = json.loads(result['output_text'])
                retry_limit = 0
                error_flagged = False
            except Exception as e:
                error_flagged = True
                retry_limit -= 1
                print_debug(f"Error: {e}, retry counter decreased: {retry_limit}")
        if debug:
            write_json_to_file(result, f"{base_dir}/{Path(pdf).stem}.json")
        if error_flagged:
            print_debug(f"Error flagged for: {pdf}")
            raise Exception("Error flagged, no retries left, exiting - if debug, see the JSON file for the last run.")
        print_debug(f"--> Finished LLM Portion for: {pdf}, cloze pulled: {len(result_json)}")
        metadata: str = f"Book: {'-'.join(pdf.name.split('-')[0:-1])}, Chapter: {pdf.name.split('-')[-1]}"
        deck_contents[metadata] = result_json
        print_debug(f"--> Finished processing: {pdf}")
    if generate_deck:
        print_debug(f"Generating Anki Deck: {base_dir}/{file_stem}.apkg")
        generate_anki_deck(deck_contents, f"{base_dir}/{file_stem}.apkg", deck_title)


if __name__ == '__main__':
    createFlashcards()