Last active
August 29, 2015 14:00
-
-
Save bbengfort/11292903 to your computer and use it in GitHub Desktop.
Cloze Analysis using NLTK
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import os | |
import sys | |
import nltk | |
import argparse | |
import unicodecsv as csv | |
from operator import itemgetter | |
# Default location of the cloze output data file, resolved relative to this script.
PATH = os.path.normpath(os.path.join(os.path.dirname(__file__), 'cloze_output.txt'))

# Command-line interface metadata used by the argparse parser in main().
DESCRIPTION = "Analyze adj-noun bigram distributions from cloze data and the Penn Treebank"
EPILOG = "Please email Ben for any problems or concerns with this code."
VERSION = "1.3"
class ClozeAnalysis(object):
    """
    Analyzes adjective-noun bigram distributions collected from a cloze
    task and compares them against the Penn Treebank corpus via NLTK.

    The input file is tab-delimited; only columns 3 (adjective prompt)
    and 4 (subject-supplied noun) are used.
    """

    def __init__(self, cloze_output=PATH):
        """
        :param cloze_output: path to the tab-delimited cloze output file.

        The file is read eagerly; ``self.data`` holds lowercased
        (adjective, noun) tuples after construction.
        """
        self.path = cloze_output
        self.data = []  # list of (adjective, noun) tuples, lowercased
        self.read()

    def read(self):
        """
        Reads the cloze_output.txt file and stores (adjective, noun) pairs.

        COLUMNS are as follows:
            1. Subject ID
            2. Condition (filter?)
            3. Adjective Prompt
            4. Noun supplied by subject
            5. Some sort of encoding
        """
        with open(self.path, 'r') as f:
            reader = csv.reader(f, delimiter='\t')
            for row in reader:
                # Only the adjective (col 3) and noun (col 4) matter here.
                self.data.append((row[2].lower(), row[3].lower()))

    ## Data access functions

    def cloze_unigrams(self):
        """Yield every word (adjectives then nouns, per pair) from the cloze data."""
        for pair in self.data:
            for item in pair:
                yield item

    def cloze_bigrams(self):
        """Yield each (adjective, noun) pair from the cloze data."""
        for pair in self.data:
            yield pair

    def cloze_adjectives(self):
        """Yield the adjective (first element) of each cloze pair."""
        for pair in self.data:
            yield pair[0]

    def cloze_nouns(self):
        """Yield the noun (second element) of each cloze pair."""
        for pair in self.data:
            yield pair[1]

    def treebank_unigrams(self):
        """Yield every Penn Treebank word, lowercased to match the cloze data."""
        for word in nltk.corpus.treebank.words():
            yield word.lower()

    def treebank_bigrams(self, adjective_filter=False):
        """
        Yield word bigrams from the Penn Treebank.

        If adjective_filter is True, only yield bigrams whose first word
        is one of the cloze adjectives.
        """
        if adjective_filter:
            # Build the membership set once, outside the loop.
            adjectives = set(self.cloze_adjectives())
        for bigram in nltk.bigrams(self.treebank_unigrams()):
            if adjective_filter and bigram[0] not in adjectives:
                continue
            yield bigram

    ## Computing Cloze Probabilities

    def adjnoun_distribution(self, corpus='cloze'):
        """Return an nltk.FreqDist of (adj, noun) bigrams for the named corpus."""
        if corpus == 'cloze':
            return nltk.FreqDist(self.cloze_bigrams())
        if corpus == 'treebank':
            return nltk.FreqDist(self.treebank_bigrams())
        raise Exception("No corpus named '%s' found." % corpus)

    def conditional_distribution(self, corpus='cloze'):
        """Return an nltk.ConditionalFreqDist of noun frequencies keyed by adjective."""
        if corpus == 'cloze':
            return nltk.ConditionalFreqDist(self.cloze_bigrams())
        if corpus == 'treebank':
            return nltk.ConditionalFreqDist(self.treebank_bigrams())
        raise Exception("No corpus named '%s' found." % corpus)

    def cloze_rows(self):
        """Yield {'adj', 'noun', 'freq'} dicts for the cloze corpus, sorted by adjective."""
        dist = self.conditional_distribution(corpus='cloze')
        for key, val in sorted(dist.items(), key=itemgetter(0)):
            for k, v in val.items():
                yield dict(zip(('adj', 'noun', 'freq'), (key, k, v)))

    def treebank_rows(self):
        """
        Yield {'adj', 'noun', 'freq'} dicts for the treebank corpus, sorted
        by adjective, restricted to adjectives that occur in the cloze data.
        """
        dist = self.conditional_distribution(corpus='treebank')
        adjs = set(self.cloze_adjectives())
        for key, val in sorted(dist.items(), key=itemgetter(0)):
            if key in adjs:
                for k, v in val.items():
                    yield dict(zip(('adj', 'noun', 'freq'), (key, k, v)))

    ## Comparing Cloze to Penn Treebank

    def compare_unigrams(self):
        """Yield (token, cloze frequency, treebank frequency) per cloze token."""
        clozedist = nltk.FreqDist(self.cloze_unigrams())
        penntdist = nltk.FreqDist(self.treebank_unigrams())
        for token in clozedist.keys():
            yield token, clozedist[token], penntdist[token]

    def compare_bigrams(self):
        """Yield (bigram, cloze frequency, treebank frequency) per cloze bigram."""
        clozedist = nltk.FreqDist(self.cloze_bigrams())
        penntdist = nltk.FreqDist(self.treebank_bigrams())
        for token in clozedist.keys():
            yield token, clozedist[token], penntdist[token]

    def intersection(self, pos='both'):
        """
        Reports the intersection of nouns or adjectives (or both) in the
        two corpora- e.g. what is shared among them.
        """
        if pos == 'both':
            return set(self.cloze_unigrams()) & set(self.treebank_unigrams())
        if pos == 'adj':
            return set(self.cloze_adjectives()) & set(self.treebank_unigrams())
        if pos == 'noun':
            return set(self.cloze_nouns()) & set(self.treebank_unigrams())
        if pos == 'bigram':
            return set(self.cloze_bigrams()) & set(self.treebank_bigrams())
        # BUGFIX: previously fell through and silently returned None.
        raise Exception("No part of speech named '%s' found." % pos)

    ## Helper functions

    def pprint(self, corpus='cloze'):
        """Return a human-readable string of noun frequencies grouped by adjective."""
        output = []
        dist = self.conditional_distribution(corpus=corpus)
        for key, val in sorted(dist.items(), key=itemgetter(0)):
            output.append("%s:" % key)
            for k, v in val.items():
                output.append("    %2i %s" % (v, k))
        return "\n".join(output)

    def info(self):
        """Return a summary string of corpus sizes and cloze/treebank overlap counts."""
        return (
            "%i cloze words with %i in common with the treebank\n"
            "%i cloze adjectives with %i in common with treebank\n"
            "%i cloze nouns with %i in common with treebank\n"
            "%i adj-noun bigrams with %i in common with treebank\n"
            "%i treebank bigrams start with adjective in cloze adjectives"
        ) % (
            len(set(self.cloze_unigrams())), len(self.intersection('both')),
            len(set(self.cloze_adjectives())), len(self.intersection('adj')),
            len(set(self.cloze_nouns())), len(self.intersection('noun')),
            len(set(self.cloze_bigrams())), len(self.intersection('bigram')),
            len(set(self.treebank_bigrams(True))),
        )

    def __str__(self):
        return self.pprint()

    ### Output functions

    def output(self, stream, corpus='cloze', format='csv'):
        """Write the analysis to ``stream`` in the given format ('csv', 'txt', 'info')."""
        jump = {
            'csv': self.output_csv,
            'txt': self.output_txt,
            'info': self.output_info,
        }
        if format not in jump:
            # Raise a descriptive error instead of an opaque KeyError,
            # matching the style of the distribution methods above.
            raise Exception("No format named '%s' found." % format)
        writer = jump[format]
        writer(stream, corpus=corpus)

    def output_txt(self, stream, corpus='cloze'):
        """Write the pretty-printed conditional distribution to ``stream``."""
        stream.write(self.pprint(corpus=corpus))
        stream.write("\n")

    def output_csv(self, stream, corpus='cloze'):
        """
        Output frequencies of nouns per adjective to compute cloze.
        """
        jump = {
            # BUGFIX: this key was 'corpus', so the default corpus='cloze'
            # (and the only other documented value besides 'treebank')
            # always raised KeyError.
            'cloze': self.cloze_rows,
            'treebank': self.treebank_rows,
        }
        if corpus not in jump:
            raise Exception("No corpus named '%s' found." % corpus)
        rows = jump[corpus]
        writer = csv.DictWriter(stream, fieldnames=('adj', 'noun', 'freq'), delimiter='\t')
        writer.writeheader()
        for row in rows():
            writer.writerow(row)

    def output_info(self, stream, corpus='cloze'):
        """Write the overlap summary to ``stream`` (``corpus`` is ignored here)."""
        stream.write(self.info())
        stream.write('\n')
def main(*argv):
    """
    Command-line entry point: parse arguments, run the cloze analysis,
    and write the result to the requested output stream.
    """
    # BUGFIX: ArgumentParser(version=...) is Python 2-only; the kwarg was
    # removed in Python 3 (TypeError). Expose the version via the standard
    # --version action instead.
    parser = argparse.ArgumentParser(description=DESCRIPTION, epilog=EPILOG)
    parser.add_argument('-V', '--version', action='version', version='%(prog)s ' + VERSION)
    parser.add_argument('cloze_output', default=PATH, nargs='?', help="The path to the cloze distribution file.")
    parser.add_argument('-c', '--corpus', type=str, default='cloze', choices=('cloze', 'treebank'), help='Corpus to perform work on.')
    parser.add_argument('-o', '--outpath', metavar='PATH', type=argparse.FileType('w'), default=sys.stdout, help='Where to write out the results of analysis.')
    parser.add_argument('-f', '--format', type=str, default='txt', choices=('csv', 'txt', 'info'), help='Format of the output.')
    args = parser.parse_args()

    analyzer = ClozeAnalysis(args.cloze_output)
    analyzer.output(args.outpath, corpus=args.corpus, format=args.format)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment