Created
May 26, 2017 10:54
-
-
Save jbg/971c04be74a3c11fcd7f6ac530515adb to your computer and use it in GitHub Desktop.
Similarity model example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import cPickle | |
__import__('__init__').init_triposo() | |
from sklearn.decomposition import TruncatedSVD, NMF | |
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer | |
from time import time | |
from norwegianblue import flags, mapreduce, store, base | |
from norwegianblue.status import Status | |
# Input store of tokenized POI records; read by enumerate_pois().
flags.pois_tokenized = flags.File('prod/pois_tokenized@24')
# Output store; main() writes every POI together with its reduced vector here.
flags.pois_vectors = flags.File('tmp/pois_vector@24')
# Path main() pickles the {category: (vectorizer, reducer)} dict to.
flags.model = flags.File('tmp/poi_sim_model.pickle')
# When non-zero, enumerate_pois() stops scanning at this record index.
flags.poi_record_limit = flags.Int(0)
# Dimensionality-reduction algorithm selected in main(): 'SVD' or 'NMF'.
flags.algorithm = flags.Enum(['SVD', 'NMF'])
# Passed as n_components to the reducer built in main().
flags.dimensions = flags.Int(default_value=100)
def enumerate_pois(cat):
    """Yield ``(poi_id, poi)`` for each stored POI whose 'poicat' equals *cat*.

    Records are read from the store named by ``flags.pois_tokenized``. When
    ``flags.poi_record_limit`` is non-zero, scanning stops once the record
    index (counted over all records, not only matching ones) reaches it.

    Raises:
        ZeroDivisionError: after the scan finishes, if no record matched.
            main() catches this to skip a category that has no documents.
    """
    progress = Status()
    matched = 0
    records = store.open(flags.pois_tokenized).items()
    for position, (poi_id, poi) in enumerate(records):
        if flags.poi_record_limit and flags.poi_record_limit == position:
            break
        if poi['poicat'] != cat:
            continue
        progress.count('poi').report()
        matched += 1
        yield poi_id, poi
    progress.log('enumeration-done')
    if not matched:
        # Sentinel for "empty category"; the exception type is part of the
        # contract with main(), so it must stay ZeroDivisionError.
        raise ZeroDivisionError
def enumerate_docs(cat):
    """Yield the token list of every POI in category *cat*.

    Each POI's 'tokens' field is a single '|'-separated string; it is split
    into a list here. Any exception raised by enumerate_pois() (including the
    ZeroDivisionError it raises for an empty category) propagates unchanged.
    """
    for _, record in enumerate_pois(cat):
        joined_tokens = record['tokens']
        yield joined_tokens.split('|')
def main():
    """Fit one TF-IDF + dimensionality-reduction model per POI category.

    For each of the categories 'sleep', 'see' and 'eatdrink':

    1. fit a TfidfVectorizer on the pre-tokenized documents of that category;
    2. reduce the TF-IDF matrix with TruncatedSVD or NMF, as selected by
       ``flags.algorithm``, keeping ``flags.dimensions`` components;
    3. write every POI (without its 'tokens' field, with an added 'vector'
       list) to the ``flags.pois_vectors`` store.

    Categories without any documents are skipped. Finally the dict
    ``{category: (vectorizer, reducer)}`` is pickled to ``flags.model``.

    Raises:
        ValueError: if ``flags.algorithm`` is neither 'SVD' nor 'NMF'.
    """
    pois_vectors = store.open(flags.pois_vectors, mode='w')
    status = Status()
    model = {}
    for c in 'sleep', 'see', 'eatdrink':
        print('Extracting features from the pois for %s' % c)
        t0 = time()
        # Documents arrive as token lists already, so the tokenizer is the
        # identity function and lowercasing is switched off.
        vectorizer = TfidfVectorizer(  # ngram_range=(1, 2),
            min_df=10,
            tokenizer=lambda tokens: tokens,
            lowercase=False,
            use_idf=True)
        print('fitting the model')
        try:
            corpus = vectorizer.fit_transform(enumerate_docs(c))
        except ZeroDivisionError:
            # enumerate_pois() signals an empty category this way.
            print('no documents.')
            continue
        print('done in %fs' % (time() - t0))
        print('n_samples: %d, n_features: %d' % corpus.shape)

        print('Performing dimensionality reduction using %s' % flags.algorithm)
        t0 = time()
        if flags.algorithm == 'SVD':
            dimension_reducer = TruncatedSVD(n_components=flags.dimensions)
        elif flags.algorithm == 'NMF':
            dimension_reducer = NMF(n_components=flags.dimensions)
        else:
            # Without this branch an unexpected value would either hit a
            # NameError or silently reuse the previous category's reducer.
            raise ValueError('unsupported algorithm: %r' % (flags.algorithm,))
        transformed = dimension_reducer.fit_transform(corpus)
        print('done in %fs' % (time() - t0))

        if flags.algorithm == 'SVD':
            # Best-effort report: the attribute may be missing (presumably on
            # older scikit-learn releases), in which case it is skipped.
            try:
                variance_ratio = dimension_reducer.explained_variance_ratio_
            except AttributeError:
                pass
            else:
                explained_variance = variance_ratio.sum()
                print('Explained variance of the SVD step: {}%'.format(
                    int(explained_variance * 100)))
        elif flags.algorithm == 'NMF':
            # Was '%2.f' (width 2, zero decimals); two decimals were intended.
            print('Reconstruction error %.2f'
                  % dimension_reducer.reconstruction_err_)

        # The lambda tokenizer cannot be pickled, so it is dropped before the
        # vectorizer is stored. NOTE(review): whoever loads the model has to
        # install an equivalent identity tokenizer again - confirm with the
        # consumer of flags.model.
        vectorizer.tokenizer = None
        model[c] = (vectorizer, dimension_reducer)

        # Second pass over the same category: this relies on enumerate_pois()
        # yielding the POIs in the same order as during fitting, so that row
        # i of `transformed` belongs to the i-th POI.
        for (poi_id, poi), poi_vector in zip(enumerate_pois(c), transformed):
            del poi['tokens']
            status.count('poi-out').report()
            pois_vectors[poi_id] = dict(poi, vector=list(poi_vector))

    # Protocol -1 is a binary pickle format, so the file must be opened in
    # binary mode; open() replaces the Python-2-only file() builtin.
    with open(flags.model, 'wb') as f:
        cPickle.dump(model, f, -1)
if __name__ == '__main__':
    # Runs before main(); presumably this is where the command line is parsed
    # into the flags declared above - confirm against norwegianblue.flags.
    flags.non_flag_components()
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment