YFCC100M tag prediction: dataset cleaning (Python)
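This script builds a fastText-style training set from the YFCC100M metadata dumps. It walks TRAIN_DATASET_DIR for files whose names contain "_dataset", counts every token appearing in the title, description, and user-tag fields, keeps the tokens that occur at least KEEPWORDS_THRESHOLD times (caching them in keepwords.txt so later runs can skip the counting pass), and then rewrites each record as "__label__<tag>" labels followed by the surviving title and description words. The column indices used below (6 = title, 7 = description, 8 = user tags) assume the standard tab-separated YFCC100M metadata layout.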
import os
import re
import collections
import urllib.parse
from time import time
from multiprocessing import Pool

KEEPWORDS_FILE = "keepwords.txt"             # cached vocabulary, one word per line
TRAIN_DATASET_DIR = "../yfcc100m/"           # directory holding the *_dataset metadata files
CLEANED_TRAIN_FILE_WRITE_INTERVAL = 500000   # records buffered before each write
KEEPWORDS_THRESHOLD = 100                    # minimum occurrences for a word to be kept
WORDCOUNT_WORKERS = 2                        # processes for the counting pass
CLEAN_WORKERS = 6                            # processes for the cleaning pass
def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    # string = re.sub(r"[^A-Za-z0-9(),!?\'\`_]", " ", string)
    string = re.sub(r"<.*?>", " ", string)    # strip HTML tags
    string = re.sub(r"\'s", " 's", string)    # split clitics off their host word
    string = re.sub(r"\'ve", " 've", string)
    string = re.sub(r"n\'t", " n't", string)
    string = re.sub(r"\'re", " 're", string)
    string = re.sub(r"\'d", " 'd", string)
    string = re.sub(r"\'ll", " 'll", string)
    string = re.sub(r",", " , ", string)      # pad punctuation into separate tokens
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", r" \( ", string)   # raw replacement strings avoid invalid-escape
    string = re.sub(r"\)", r" \) ", string)   # warnings; the escaped tokens match the
    string = re.sub(r"\?", r" \? ", string)   # original tokenizer's output
    string = re.sub(r"\n", " ", string)
    string = re.sub(r"\r", " ", string)
    string = re.sub(r"\s{2,}", " ", string)   # collapse runs of whitespace
    return string.strip().lower()
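# For reference, a worked example of clean_str on a hypothetical input:
#   clean_str("Don't <b>stop</b>!")  ->  "do n't stop !"
# HTML tags are blanked, clitics are split off, punctuation is padded into
# its own tokens, and the result is whitespace-collapsed and lower-cased.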
def wordcount_worker(path):
    print('wordcount worker started : %s' % path)
    wordcount = collections.Counter()
    count = 0
    words = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            count += 1
            sline = line.split('\t')
            # user tags: percent-decoded, comma-separated; '+' joins multi-word tags
            words += [k.strip() for k in clean_str(urllib.parse.unquote(sline[8])).replace('+', '_').split(',') if k.strip() != '']
            # title & description: '+' encodes spaces here, hence unquote_plus
            words += [k.strip() for k in clean_str(urllib.parse.unquote_plus(sline[6] + ' ' + sline[7])).split() if k.strip() != '']
            # fold each batch into the counter periodically to bound memory
            if count % 100000 == 0:
                wordcount.update(words)
                words[:] = []
            if count % 1000000 == 0:
                print('%s : line %d passed' % (path, count))
    wordcount.update(words)    # count the final partial batch (dropped in the original)
    print('wordcount worker finished : %s' % path)
    return wordcount
def clean_data(tags, titles, descriptions):
    string = ""
    for t, ti, desc in zip(tags, titles, descriptions):
        # tags become fastText-style labels; '_' keeps multi-word tags as one token
        t_tags = clean_str(urllib.parse.unquote(t)).replace('+', '_').split(',')
        t_tags = [k.strip() for k in t_tags if k.strip() in keepwords]
        t_tags = ['__label__' + k for k in t_tags]
        t_titles = clean_str(urllib.parse.unquote_plus(ti))
        t_titles = [k.strip() for k in t_titles.split() if k.strip() in keepwords]
        t_descriptions = clean_str(urllib.parse.unquote_plus(desc))
        t_descriptions = [k.strip() for k in t_descriptions.split() if k.strip() in keepwords]
        # skip records with no surviving text or no surviving labels
        if len(t_titles) < 1 and len(t_descriptions) < 1:
            continue
        if len(t_tags) < 1:
            continue
        if len(t_tags) == 1 and t_tags[0] == '__label__':
            continue
        string += "%s %s %s\n" % (' '.join(t_tags), ' '.join(t_titles), ' '.join(t_descriptions))
    return string
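# For reference, what one cleaned record looks like (hypothetical values,
# assuming "sunset", "beach", "golden", "hour", "at", "the", and "pier"
# all survive the keepwords filter):
#   tags  "sunset,beach"  ->  "__label__sunset __label__beach"
#   title "Golden+hour"   ->  "golden hour"
#   desc  "at+the+pier"   ->  "at the pier"
#   line  "__label__sunset __label__beach golden hour at the pier"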
def clean_worker(path):
    print("clean worker started : %s" % path)
    tags, titles, descriptions = ([] for _ in range(3))
    count = total_count = 0
    with open(path + '_cleaned', 'w', encoding='utf-8') as w:
        with open(path, encoding='utf-8') as f:
            for line in f:
                count += 1
                total_count += 1
                sline = line.split('\t')
                titles.append(sline[6])
                descriptions.append(sline[7])
                tags.append(sline[8])
                # clean and flush one batch at a time to bound memory
                if count == CLEANED_TRAIN_FILE_WRITE_INTERVAL:
                    w.write(clean_data(tags, titles, descriptions))
                    print("%s line processed : %d" % (path, total_count))
                    tags[:], titles[:], descriptions[:] = ([] for _ in range(3))
                    count = 0
        # flush the final partial batch
        if len(tags) > 0:
            w.write(clean_data(tags, titles, descriptions))
    print("clean worker finished : %s" % path)
# Shared vocabulary. Populated in the main process before the cleaning pool
# starts, so workers inherit it; this assumes a fork-based multiprocessing
# start method (the default on Linux).
keepwords = set()
if __name__ == '__main__':
    if not os.path.exists(KEEPWORDS_FILE):
        ## calculate all word counts
        t0 = time()
        files = []
        for (dirpath, dirnames, filenames) in os.walk(TRAIN_DATASET_DIR):
            for filename in filenames:
                if "_dataset" in filename and "_cleaned" not in filename:
                    files.append(os.path.join(dirpath, filename))
        wordcount = collections.Counter()
        with Pool(processes=WORDCOUNT_WORKERS) as pool:
            for res in pool.imap_unordered(wordcount_worker, files):
                wordcount.update(res)
        print("duration : %0.3fs" % (time() - t0))

        ## set keep words
        t0 = time()
        print("Set keep words...")
        for word, c in wordcount.items():
            if c >= KEEPWORDS_THRESHOLD:
                keepwords.add(word)
        wordcount = None    # release the full counter
        print("keep words : %d ( count >= %d )" % (len(keepwords), KEEPWORDS_THRESHOLD))
        print("duration : %0.3fs" % (time() - t0))

        ## write keep words to file so later runs can skip the counting pass
        with open(KEEPWORDS_FILE, "w", encoding='utf-8') as w:
            for word in keepwords:
                w.write("%s\n" % word)

    ## load keep words (from this run's file or a cached one)
    with open(KEEPWORDS_FILE, encoding='utf-8') as f:
        for line in f:
            for s in line.split():
                keepwords.add(s)

    ## keep keepwords and remove others
    files = []
    for (dirpath, dirnames, filenames) in os.walk(TRAIN_DATASET_DIR):
        for filename in filenames:
            if "_dataset" in filename and "_cleaned" not in filename:
                files.append(os.path.join(dirpath, filename))
    with Pool(processes=CLEAN_WORKERS) as pool:
        for res in pool.imap_unordered(clean_worker, files):
            pass
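Because every cleaned line starts with one or more "__label__<tag>" tokens followed by the surviving title and description words, the *_cleaned files can be fed straight into fastText's supervised mode. A minimal sketch, assuming the official fasttext Python bindings are installed and the *_cleaned files have been concatenated into a hypothetical train.txt:

import fasttext

# train.txt holds lines such as:
#   __label__sunset __label__beach golden hour at the pier
model = fasttext.train_supervised(input="train.txt")

# predict the top 3 tags for an unseen title/description string
print(model.predict("golden hour at the pier", k=3))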