Skip to content

Instantly share code, notes, and snippets.

@jessstringham
Created June 20, 2023 04:23
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jessstringham/2a787140dd09cc76f3486a5858a33ff2 to your computer and use it in GitHub Desktop.
Save jessstringham/2a787140dd09cc76f3486a5858a33ff2 to your computer and use it in GitHub Desktop.
Samples for a session at ITP camp 2023
from collections import defaultdict
import random
# Number of preceding tokens the model conditions on when picking the next one.
WINDOW_SIZE = 2
# Decode explicitly instead of relying on the platform's locale-dependent
# default; "utf-8-sig" reads plain UTF-8 too and drops a leading byte-order
# mark if the file has one (otherwise it would stick to the first word).
with open("alice.txt", encoding="utf-8-sig") as f:
    full_text = f.read()
# Each line of the file is treated as one paragraph: split it on whitespace
# into tokens (str.split() already ignores leading/trailing whitespace) and
# append an "(END)" token so generation knows where a paragraph stops.
# Blank lines and lines with WINDOW_SIZE + 1 or fewer words are dropped.
paragraphs = [
    tokens + ["(END)"]
    for tokens in (line.split() for line in full_text.split("\n"))
    if len(tokens) > WINDOW_SIZE + 1
]
# first_counts: how many paragraphs open with each window of WINDOW_SIZE
# tokens; it will end up like
# {
#     ("This", "will"): 12,
# }
first_counts = defaultdict(int)
# next_counts: for each window, how many times each token immediately follows
# it; it will end up like
# {
#     ("this", "will"): {"end": 11, "look": 12},
# }
next_counts = defaultdict(lambda: defaultdict(int))
for paragraph in paragraphs:
    first_counts[tuple(paragraph[:WINDOW_SIZE])] += 1
    # Slide the window one token at a time; the token just past the window is
    # its successor. A window ending on the paragraph's last token has no
    # successor, so it is not counted.
    for start in range(len(paragraph) - WINDOW_SIZE):
        window = tuple(paragraph[start:start + WINDOW_SIZE])
        # Named `follower` rather than `next` so the loop does not rebind the
        # built-in next() at module scope.
        follower = paragraph[start + WINDOW_SIZE]
        # count of ((w1, w2), w3)
        next_counts[window][follower] += 1
def roll_weighted_die(faces, weights):
    """Pick one face at random, with probability proportional to its weight.

    ``faces`` and ``weights`` are parallel iterables (for example a dict's
    ``keys()`` and ``values()``). Both are materialized as lists before
    sampling, so one-shot iterators are accepted.
    """
    population = list(faces)
    face_weights = list(weights)
    # random.choices returns a list of k (default 1) picks; unpack the single one.
    (picked,) = random.choices(population, weights=face_weights)
    return picked
def first():
    """Return a randomly chosen opening window as a new list.

    Windows are weighted by how many paragraphs begin with them. The tuple key
    is copied into a list so the caller can append generated tokens to it.
    """
    opening = roll_weighted_die(first_counts.keys(), first_counts.values())
    return list(opening)
def next(window):
    """Draw the token that follows ``window``, weighted by observed counts.

    ``window`` is a sequence of tokens; it is converted to a tuple to match the
    keys of ``next_counts``. If that window was never seen, or has no recorded
    successors, fall back to the successor counts of a randomly chosen window
    that does have some.

    Raises IndexError if ``next_counts`` holds no non-empty successor counts.

    Note: this name shadows the built-in ``next()`` at module level.
    """
    # .get() rather than [] so the defaultdict does not insert an empty entry
    # for an unseen window (which would later be a candidate for the fallback).
    probs = next_counts.get(tuple(window))
    if not probs:
        # choose something random: random.choice needs an indexable sequence
        # (a dict view raises TypeError), and empty tables are excluded because
        # random.choices cannot sample from an empty population.
        candidates = [counts for counts in next_counts.values() if counts]
        probs = random.choice(candidates)
    return roll_weighted_die(probs.keys(), probs.values())
# Now generate! Seed the paragraph with a random opening window, then keep
# drawing a successor for its trailing WINDOW_SIZE tokens until the
# end-of-paragraph marker comes up.
paragraph = first()
while paragraph[-1] != "(END)":
    paragraph.append(next(paragraph[-WINDOW_SIZE:]))
# Everything before the trailing "(END)" marker is the generated text.
print(" ".join(paragraph[:-1]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment