Last active
November 14, 2023 10:13
-
-
Save bbkane/8c9de52caa43c87a35b6ae21526e397f to your computer and use it in GitHub Desktop.
MarkovChain in SQL + Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib | |
import sqlite3 | |
# Schema for the Markov chain store; main() applies it with executescript().
#   word      - one row per distinct word. `id` is the primary key, `word` is
#               unique, and `number` is the counter bumped by insert_word().
#   word_pair - one row per ordered (first_word_id, second_word_id) pair, both
#               referencing word(id); `number` is bumped by insert_word_pair().
# Both statements use IF NOT EXISTS, so re-running the script against an
# existing database file does not fail on table creation.
CREATE_TABLES_SQL = """
CREATE TABLE IF NOT EXISTS word
(
id INTEGER NOT NULL,
word TEXT UNIQUE NOT NULL,
number INTEGER DEFAULT 1,
PRIMARY KEY (id)
);
CREATE TABLE IF NOT EXISTS word_pair
(
first_word_id INTEGER NOT NULL,
second_word_id INTEGER NOT NULL,
number INTEGER DEFAULT 1,
PRIMARY KEY (first_word_id, second_word_id),
FOREIGN KEY (first_word_id) REFERENCES word(id),
FOREIGN KEY (second_word_id) REFERENCES word(id)
);
"""
def insert_word(sql_conn, sql_cur, word):
    """Count one occurrence of ``word`` and return the id of its row.

    Two statements emulate an upsert (see
    https://stackoverflow.com/a/3661644/2958070): an UPDATE increments the
    counter of an existing row, then an INSERT guarded by ``changes() = 0``
    creates the row with a count of 1 only when the UPDATE modified nothing.
    The change is committed before the row id is read back.

    NOTE: this is probably not thread safe

    Args:
        sql_conn: connection used to commit.
        sql_cur: cursor on ``sql_conn`` used to run the statements.
        word: the word to record.

    Returns:
        The ``id`` column of the row holding ``word``.
    """
    params = (word,)

    # Step 1: bump the counter if the word is already known.
    bump_count_sql = """
    UPDATE word SET number = number + 1 WHERE word = ?
    """
    sql_cur.execute(bump_count_sql, params)

    # Step 2: changes() is the row count of the UPDATE above; zero means the
    # word is new and must be inserted.
    add_if_missing_sql = """
    INSERT INTO word (word, number) SELECT ?, 1 WHERE (SELECT changes() = 0);
    """
    sql_cur.execute(add_if_missing_sql, params)
    sql_conn.commit()

    # Step 3: read back the id, whether the row was just created or not.
    lookup_id_sql = """
    SELECT id FROM word WHERE word = ?
    """
    sql_cur.execute(lookup_id_sql, params)
    row = sql_cur.fetchone()
    return row[0]
def insert_word_pair(sql_conn, sql_cur, first_word_id, second_word_id):
    """Count one occurrence of the ordered pair of word ids.

    Uses the same update-then-guarded-insert pattern as ``insert_word``
    (https://stackoverflow.com/a/3661644/2958070): the UPDATE increments an
    existing pair's counter, and the INSERT only adds a new row (count 1)
    when ``changes()`` reports that the UPDATE modified nothing. The change
    is committed before returning.

    Args:
        sql_conn: connection used to commit.
        sql_cur: cursor on ``sql_conn`` used to run the statements.
        first_word_id: ``word.id`` of the leading word.
        second_word_id: ``word.id`` of the word that follows it.
    """
    UPDATE_WORD_PAIR_SQL = """
    UPDATE word_pair SET number = number + 1 WHERE first_word_id = ? and second_word_id = ?
    """
    INSERT_WORD_PAIR_SQL = """
    INSERT INTO word_pair (first_word_id, second_word_id, number) SELECT ?, ?, 1 WHERE (SELECT changes() = 0);
    """
    data = (first_word_id, second_word_id)
    # Run on the cursor the caller passed in (as insert_word does) instead of
    # letting the connection create a throwaway cursor for each statement.
    sql_cur.execute(UPDATE_WORD_PAIR_SQL, data)
    sql_cur.execute(INSERT_WORD_PAIR_SQL, data)
    sql_conn.commit()
# TODO: there has got to be a better way to do this... | |
def get_file_by_word(fp):
    """Yield the whitespace-separated words of ``fp`` in reading order.

    ``fp`` is iterated line by line and each line is split with
    ``str.split()``, so blank lines contribute nothing.
    """
    for text_line in fp:
        yield from text_line.split()
def get_file_by_two_words(fp):
    """Yield each pair of consecutive words in ``fp`` as ``(first, second)``.

    The pairs overlap: the second word of one pair is the first word of the
    next, so for words ``a b c`` the output is ``(a, b)`` then ``(b, c)``.
    Because the window slides by one word, no word is skipped regardless of
    whether the word count is odd or even. Input with fewer than two words
    yields nothing.
    """
    file_by_word = get_file_by_word(fp)

    # A bare next() must not be allowed to raise StopIteration inside a
    # generator: since PEP 479 that surfaces as RuntimeError. A sentinel
    # default handles the empty-input case explicitly.
    no_word = object()
    first_word = next(file_by_word, no_word)
    if first_word is no_word:
        return

    # The for loop ends cleanly when the words run out.
    for second_word in file_by_word:
        yield (first_word, second_word)
        first_word = second_word
def main():
    """Build the word and word-pair tables from ``./war_and_peace.txt``.

    Opens (or creates) ``MarkovChain.sqlite3``, ensures the schema exists,
    then walks the text as overlapping word pairs, recording every word once
    and every consecutive pair once. The connection is closed on exit.
    """
    conn = sqlite3.connect('MarkovChain.sqlite3')
    with contextlib.closing(conn):
        cur = conn.cursor()
        cur.executescript(CREATE_TABLES_SQL)
        conn.commit()
        with open('./war_and_peace.txt') as fp:
            # get_file_by_two_words yields overlapping pairs, so the first
            # word of each pair after the initial one was already recorded as
            # the previous pair's second word. Reuse that id instead of
            # calling insert_word again, which would count the word twice.
            first_word_id = None
            for first_word, second_word in get_file_by_two_words(fp):
                if first_word_id is None:
                    first_word_id = insert_word(conn, cur, first_word)
                second_word_id = insert_word(conn, cur, second_word)
                insert_word_pair(conn, cur, first_word_id, second_word_id)
                first_word_id = second_word_id
# Run the import only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment