Skip to content

Instantly share code, notes, and snippets.

@FlyingFathead
Last active June 21, 2024 00:15
Show Gist options
  • Save FlyingFathead/207e5040f3c53ddda5ba38cbd9f88ede to your computer and use it in GitHub Desktop.
Save FlyingFathead/207e5040f3c53ddda5ba38cbd9f88ede to your computer and use it in GitHub Desktop.
Word letter counting dataset generator for LLM training
#!/usr/bin/env python3
# word_letter_counter_for_llm_datasets.py
"""
A training set generator to assist an LLM in actually being able to count
the letters in a given word. A model that can't count letters in a word isn't
usable for critical tasks; incorrect letter counts lead to compounding mistakes.
Outputs in JSON, can also use the NLTK corpus for a word dictionary, offering
a quick way to create a massive letter counting dataset for different words.
The code itself is completely unoptimized, but hey, such is life.
Requires `nltk`; install with: `pip install nltk`
[ Dedicated to the LLM industry giants who still don't seem to be getting
the letter counts in a word right even in their SoTA flagship models. ]
"The problem is in the tokenizer/model structure/x/y/z"
Maybe, but try this steering set nonetheless.
`--num_words` sets how many words to take from the NLTK dictionary
(used only when `--word` is not given; default 1000).
`--word` processes a single, user-supplied word instead.
`--output_file` writes the JSON to the given file instead of printing it.
`--include_nonexistent` also lists characters that do not occur in the word,
with a count of 0 (this is the default).
`--exclude_nonexistent` omits those zero-count entries.
`--include_non_a_z` also considers Unicode letters outside the a-z range
when listing zero-count characters.
`--exclude_non_a_z` limits zero-count characters to a-z (the default).
LICENSE
=======
Go ahead and use it in anything you want. Who cares?
More of my AI & other code junk at: https://github.com/FlyingFathead
(c) FlyingFathead 2024
"""
import json
from nltk.corpus import words
import nltk
import argparse
import sys
import string
import unicodedata
def download_nltk_words():
    """Make sure the NLTK ``words`` corpus is available locally.

    The corpus is looked up first and only downloaded when it is missing, so
    repeated runs do not contact the network. Exits the process with status 1
    if the download fails.
    """
    try:
        nltk.data.find('corpora/words')
    except LookupError:
        pass  # not installed yet; fall through to the download
    else:
        return  # already available, nothing to do

    try:
        # nltk.download() reports most failures by returning False (unless
        # raise_on_error=True), so both the return value and exceptions are
        # checked.
        ok = nltk.download('words')
    except Exception as e:
        print(f"Error downloading NLTK words corpus: {e}", file=sys.stderr)
        sys.exit(1)
    if not ok:
        print("Error downloading NLTK words corpus.", file=sys.stderr)
        sys.exit(1)
def is_alphabetic(char):
    """Return True if `char` is a Unicode letter (general category L*).

    Covers uppercase, lowercase, titlecase, modifier and other letters.
    """
    major_class = unicodedata.category(char)[0]
    return major_class == 'L'
def count_and_highlight(word, char):
    """Count occurrences of `char` in `word` (case-insensitive) and highlight them.

    Each character of `word` is lowercased on its own and compared with
    ``char.lower()``. Every match is wrapped in square brackets, keeping its
    original casing.

    Args:
        word: The word to scan.
        char: The character to look for.

    Returns:
        A ``(count, highlighted_word)`` tuple. ``count`` always equals the
        number of bracketed positions in ``highlighted_word``.
    """
    target = char.lower()
    count = 0
    pieces = []
    for c in word:
        # Compare per character instead of counting in word.lower(): str.lower()
        # is context-sensitive (e.g. Greek final sigma) and may change the
        # string's length, so a whole-word count could disagree with the
        # highlighted positions.
        if c.lower() == target:
            count += 1
            pieces.append(f'[{c}]')
        else:
            pieces.append(c)
    return count, ''.join(pieces)
def generate_training_data(lexicon, include_nonexistent, include_non_a_z):
    """Generate character counts and highlights for each word in `lexicon`.

    Args:
        lexicon: Iterable of words to process.
        include_nonexistent: If True, also add entries with count 0 for
            alphabet characters that do not occur in the word.
        include_non_a_z: If True, the zero-count alphabet is every Unicode
            letter in lowercase form; otherwise it is a-z.

    Returns:
        Dict mapping each word to ``{'total_count': len(word), <char>:
        {'count': ..., 'highlighted': ...}, ...}``. Characters found in the
        word come first, in order of first appearance; zero-count characters
        follow in code-point order, so the output is reproducible across runs.
    """
    # The alphabet is only needed for zero-count entries. Building the
    # full-Unicode variant scans every code point, so skip it otherwise.
    alphabet = []
    if include_nonexistent:
        if include_non_a_z:
            # Keys are lowercase (counting is case-insensitive), so keep only
            # letters already in lowercase form. Otherwise e.g. 'A' would be
            # reported with count 0 for a word that contains 'a'.
            alphabet = sorted(
                char for char in map(chr, range(0x110000))
                if is_alphabetic(char) and char.lower() == char
            )
        else:
            alphabet = list(string.ascii_lowercase)

    training_data = {}
    for word in lexicon:
        # Lowercase per character to match count_and_highlight's comparison.
        # dict.fromkeys keeps first-appearance order (unlike set iteration
        # order, which varies with hash seeding).
        present = dict.fromkeys(c.lower() for c in word)
        word_data = {'total_count': len(word)}
        for char in present:
            count, highlighted = count_and_highlight(word, char)
            word_data[char] = {'count': count, 'highlighted': highlighted}
        for char in alphabet:
            if char not in present:
                word_data[char] = {'count': 0, 'highlighted': word}
        training_data[word] = word_data
    return training_data
def process_single_word(word, include_nonexistent, include_non_a_z):
    """Build the character-count record for a single user-provided word.

    Args:
        word: The word to process.
        include_nonexistent: If True, also add entries with count 0 for
            alphabet characters that do not occur in the word.
        include_non_a_z: If True, the zero-count alphabet is every Unicode
            letter in lowercase form; otherwise it is a-z.

    Returns:
        ``{'total_count': len(word), <char>: {'count': ..., 'highlighted': ...}}``.
        Characters found in the word come first, in order of first appearance;
        zero-count characters follow in code-point order.
    """
    # Lowercase per character to match count_and_highlight's comparison.
    # dict.fromkeys keeps first-appearance order, so output is reproducible.
    present = dict.fromkeys(c.lower() for c in word)
    word_data = {'total_count': len(word)}
    for char in present:
        count, highlighted = count_and_highlight(word, char)
        word_data[char] = {'count': count, 'highlighted': highlighted}

    if include_nonexistent:
        if include_non_a_z:
            # Keys are lowercase (counting is case-insensitive), so keep only
            # letters already in lowercase form. Otherwise e.g. 'A' would be
            # reported with count 0 for a word that contains 'a'.
            alphabet = sorted(
                char for char in map(chr, range(0x110000))
                if is_alphabetic(char) and char.lower() == char
            )
        else:
            alphabet = string.ascii_lowercase
        for char in alphabet:
            if char not in present:
                word_data[char] = {'count': 0, 'highlighted': word}
    return word_data
def main():
    """Parse command-line options, build the dataset and emit it as JSON.

    With ``--word`` only that word is processed. Otherwise the first
    ``--num_words`` entries of the NLTK ``words`` corpus are used. The JSON is
    written to ``--output_file`` when given, else printed to stdout. Exits
    with status 1 on corpus or file-write errors, and status 2 (via argparse)
    on invalid arguments.
    """
    # Set up argument parser
    parser = argparse.ArgumentParser(description='Generate character count and highlight data for words.')
    parser.add_argument('--num_words', type=int, default=1000, help='Number of words to process from the lexicon')
    parser.add_argument('--word', type=str, help='A specific word to process')
    parser.add_argument('--output_file', type=str, help='File to save the generated training data')
    # Paired include/exclude flags share one destination; the last one given wins.
    parser.add_argument('--include_nonexistent', dest='include_nonexistent', action='store_true', help='Include characters with zero occurrences')
    parser.add_argument('--exclude_nonexistent', dest='include_nonexistent', action='store_false', help='Do not include characters with zero occurrences')
    # Zero-count characters are included by default
    parser.set_defaults(include_nonexistent=True)
    parser.add_argument('--include_non_a_z', dest='include_non_a_z', action='store_true', help='Include characters outside the a-z range')
    parser.add_argument('--exclude_non_a_z', dest='include_non_a_z', action='store_false', help='Do not include characters outside the a-z range')
    # Non-a-z characters are excluded by default
    parser.set_defaults(include_non_a_z=False)
    args = parser.parse_args()

    if args.num_words < 0:
        # A negative slice bound would silently mean "all but the last N words".
        parser.error('--num_words must be zero or a positive integer')

    if args.word:
        # Process a single user-provided word
        training_data = {args.word: process_single_word(args.word, args.include_nonexistent, args.include_non_a_z)}
    else:
        # Download the NLTK words corpus if needed and use it as the lexicon
        download_nltk_words()
        try:
            lexicon = words.words()[:args.num_words]  # Use the number of words specified by the user
        except LookupError:
            print("The NLTK words corpus is not available. Please download it using nltk.download('words').", file=sys.stderr)
            sys.exit(1)
        training_data = generate_training_data(lexicon, args.include_nonexistent, args.include_non_a_z)

    # Output the generated training data
    if args.output_file:
        try:
            with open(args.output_file, 'w', encoding='utf-8') as f:
                json.dump(training_data, f, indent=2)
        except OSError as e:
            print(f"Error writing to file {args.output_file}: {e}", file=sys.stderr)
            # Report the failure to the shell instead of exiting with status 0.
            sys.exit(1)
    else:
        print(json.dumps(training_data, indent=2))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment