A script that takes the chapters of a book, runs LangChain summarization against them, and creates flashcards. Please see the linked blog post (in the docstring below) for details.
"""
create_flashcards.py
This script is used to generate flashcards from a PDF file. It uses the LLM model to generate a list of
definitions and concepts from a PDF file. It then takes this list and generates a
JSON file that can be used to generate an Anki deck.
This script is a companion to the blog post: https://thedarktrumpet.com/programming/2024/01/02/generative-ai-flashcards/
This script relies on a .env file to provide the following variables:
HOST - The host for the OpenAI API
KEY - The API key for the OpenAI API
MODEL - The model to use for the OpenAI API
MAX_TOKENS - The maximum number of tokens to use for the OpenAI API
Please note that this script has NOT been tested with OpenAI, but with my own locally-hosted model. If using with OpenAI,
you may need to adjust the MAX_TOKENS and associated parameters related to stuff-size to get it to work.
"""
import json
import os
import pprint
import random as rnd
from pathlib import Path
from typing import Any

import click
import dotenv
import genanki
from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

_debug: bool = False

anki_model: genanki.Model = genanki.Model(
    rnd.randint(10000, 100000000),  # This is a random model ID
    'ai-close-model',
    fields=[
        {'name': 'Text'},
        {'name': 'Extra'}
    ],
    templates=[
        {
            'name': 'ai-card',
            'qfmt': '{{cloze:Text}}',
            'afmt': '{{cloze:Text}}<br>{{Extra}}',
        },
    ],
    css=""".card {
  font-family: arial;
  font-size: 20px;
  text-align: center;
  color: black;
  background-color: white;
}
.cloze {
  font-weight: bold;
  color: blue;
}
.nightMode .cloze {
  color: lightblue;
}
""",
    model_type=genanki.Model.CLOZE,
)

def print_help(ctx, opts, args):
    if args is False:
        return
    click.echo(ctx.get_help())
    ctx.exit()


def print_debug(message):
    if not _debug:
        return
    if isinstance(message, list):
        for ix, msg in enumerate(message):
            print_debug(f"{ix} => {msg}")
        return
    print(f"DEBUG: {message}")

def setup_openai(env) -> ChatOpenAI:
    llm: ChatOpenAI = ChatOpenAI(openai_api_base=env['HOST'],
                                 openai_api_key=env['KEY'],
                                 model_name=env['MODEL'],
                                 temperature=0.4,
                                 max_tokens=int(env['MAX_TOKENS']))
    # TikToken - https://github.com/openai/tiktoken
    # Set TIKTOKEN_CACHE_DIR to avoid the encoding download (which is blocked by the firewall),
    # see https://stackoverflow.com/questions/76106366/how-to-use-tiktoken-in-offline-mode-computer
    os.environ["TIKTOKEN_CACHE_DIR"] = "no_git"
    assert os.path.exists(os.path.join(os.environ["TIKTOKEN_CACHE_DIR"], "9b5ad71b2ce5302211f9c61530b329a4922fc6a4"))
    return llm

def load_pdf_and_split(filename: str) -> list[Document]:
    print_debug(f"Loading PDF: {filename}")
    pdf_loader: PyPDFLoader = PyPDFLoader(filename)
    docs: list[Document] = pdf_loader.load_and_split(
        text_splitter=RecursiveCharacterTextSplitter(chunk_size=12288, chunk_overlap=0))
    print_debug(f"Number of chunks: {len(docs)}")
    for ix, doc in enumerate(docs):
        print_debug(f"Length of chunk {ix}: {len(doc.page_content)}")
    return docs

def get_map_chain(llm) -> LLMChain:
    map_template: str = """The following is a set of documents
{docs}
Based on this list of docs, please pick out the major concepts, TERMS, DEFINITIONS, and ACRONYMS that are important in the document.
Do not worry about historical context (when something was introduced or implemented). Ignore anything that looks like source code.
Helpful Answer:"""
    map_prompt: PromptTemplate = PromptTemplate.from_template(map_template)
    map_chain: LLMChain = LLMChain(llm=llm, prompt=map_prompt)
    return map_chain

def get_reduce_document_chain(llm: ChatOpenAI) -> ReduceDocumentsChain:
    # Reduce
    reduce_template: str = """The following is a set of definitions and concepts:
{docs}
Take these and distill them into a final, consolidated list of at least twenty (20) definitions and concepts, in the format of cloze sentences. The goal is that these sentences
will be inserted into Anki. Please provide the final list as a FULLY VALID JSON LIST, NOT a dictionary!
An example of the output format I'm requesting is the following:
["A {{{{c1::cat}}}} is a {{{{c2::furry}}}} animal that {{{{c3::meows}}}}.", "A {{{{c1::dog}}}} is a {{{{c2::furry}}}} animal that {{{{c3::barks}}}}.", "A {{{{c1::computer}}}} is a machine that computes."]
Helpful Answer:"""
    reduce_prompt: PromptTemplate = PromptTemplate.from_template(reduce_template)
    # Run chain
    reduce_chain: LLMChain = LLMChain(llm=llm, prompt=reduce_prompt)
    # Takes a list of documents, combines them into a single string, and passes this to an LLMChain
    combine_documents_chain: StuffDocumentsChain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs")
    # Combines and iteratively reduces the mapped documents
    reduce_documents_chain: ReduceDocumentsChain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        token_max=15000)
    return reduce_documents_chain
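
# How the pieces fit together: the map chain runs once per PDF chunk to pull out
# candidate terms, then the reduce chain stuffs those per-chunk answers together
# (collapsing them first if they exceed token_max) and distills them into a single
# JSON list of cloze sentences.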

def run_chain(llm: ChatOpenAI, docs: list[Document]) -> dict[str, Any]:
    # Combining documents by mapping a chain over them, then combining the results
    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=get_map_chain(llm),
        reduce_documents_chain=get_reduce_document_chain(llm),
        document_variable_name="docs",
        return_intermediate_steps=False,
        return_map_steps=True
    )
    result: dict[str, Any] = map_reduce_chain(docs)
    return result

def write_json_to_file(result: dict[str, Any], out_file: str) -> None:
    # Note: this writes a pretty-printed Python repr of the result for debugging,
    # not strict JSON.
    with open(out_file, "w") as f:
        f.write(pprint.pformat(result))

def generate_anki_deck(contents: dict[str, list], out_file: str, deck_title: str) -> None:
    anki_deck: genanki.Deck = genanki.Deck(
        rnd.randint(10000, 100000000),
        deck_title)
    for k, v in contents.items():
        print_debug(f"Processing: {k} - {len(v)} to be added")
        for item in v:
            my_note: genanki.Note = genanki.Note(
                model=anki_model,
                fields=[item, k]
            )
            anki_deck.add_note(my_note)
    genanki.Package(anki_deck).write_to_file(out_file)
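
# For illustration (hypothetical values), `contents` maps one chapter's metadata string
# to its list of cloze sentences:
#   {"Book: my-book, Chapter: 01.pdf": ["A {{c1::cat}} is a {{c2::furry}} animal that {{c3::meows}}."]}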

@click.command()
@click.argument('first_file', type=click.Path(exists=True))
@click.option("--debug", default=False, help='Debug output', is_flag=True)
@click.option("--single-file", default=False, help='Single file only', is_flag=True)
@click.option("--generate-deck", default=True, help='Generate Anki deck', is_flag=True)
@click.option("--deck-title", default="AI Flashcards", help='Deck title')
@click.option("--env", default=".env", help='Path to .env file, default is .env')
@click.option('--help',
              is_flag=True,
              expose_value=False,
              is_eager=True,  # eager, so --help is handled even when FIRST_FILE is missing
              callback=print_help,
              help="Print help message")
def createFlashcards(first_file: str, single_file: bool, generate_deck: bool, deck_title: str, env: str, debug: bool) -> None:
    global _debug
    env_values: dict = dotenv.dotenv_values(env)
    _debug = debug
    base_dir: str = os.path.dirname(first_file) or "."  # fall back to cwd when the path has no directory component
    file_base: str = os.path.basename(first_file)
    file_stem: str = str.join('-', str.split(file_base, '-')[0:-1])
    llm: ChatOpenAI = setup_openai(env_values)
    if single_file:
        pdfs_to_process: list[Path] = [Path(first_file)]
    else:
        pdfs_to_process = list(Path(base_dir).glob(f"{file_stem}*.pdf"))
    deck_contents: dict = {}
    for pdf in pdfs_to_process:
        print_debug(f"Processing: {str(pdf)}")
        docs: list[Document] = load_pdf_and_split(str(pdf))
        retry_limit: int = 3
        error_flagged: bool = False
        result: dict[str, Any] = {}  # initialized so the debug write below is safe even if every attempt fails
        while retry_limit > 0:
            try:
                result = run_chain(llm, docs)
                result_json: list = json.loads(result['output_text'])
                error_flagged = False
                break
            except Exception as e:
                error_flagged = True
                retry_limit -= 1
                print_debug(f"Error: {e}, retry counter decreased: {retry_limit}")
        if debug:
            write_json_to_file(result, f"{base_dir}/{Path(pdf).stem}.json")
        if error_flagged:
            print_debug(f"Error flagged for: {pdf}")
            raise Exception("Error flagged, no retries left, exiting - if debug, see the JSON file for the last run.")
        print_debug(f"--> Finished LLM portion for: {pdf}, cloze pulled: {len(result_json)}")
        metadata: str = f"Book: {str.join('-', str.split(pdf.name, '-')[0:-1])}, Chapter: {str.split(pdf.name, '-')[-1]}"
        deck_contents[metadata] = result_json
        print_debug(f"--> Finished processing: {pdf}")
    if generate_deck:
        print_debug(f"Generating Anki Deck: {base_dir}/{file_stem}.apkg")
        generate_anki_deck(deck_contents, f"{base_dir}/{file_stem}.apkg", deck_title)

if __name__ == '__main__':
    createFlashcards()
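
# Example invocation (hypothetical file names): this processes every chapter PDF sharing
# the "my-book" prefix and writes my-book.apkg alongside the PDFs:
#
#   python create_flashcards.py my-book-01.pdf --debug --deck-title "My Book"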