Last active
August 29, 2015 14:14
-
-
Save evz/5cfddf0a9240b3d403ab to your computer and use it in GitHub Desktop.
Review Prediction machine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy | |
import rlr | |
from collections import OrderedDict | |
class ReviewMachine(object):
    """Rank entities for manual review using a logistic model.

    Each call to ``label`` refits the model on every labelled example;
    ``get_next`` then re-scores all examples and hands back the unlabelled
    entity with the lowest score.
    """

    def __init__(self, entity_examples):
        """
        Entity examples should be a dict where the key is the entity_id
        and the value is a dict like so:

        {"<entity_id>": {
            "label": None,     # None by default, will be labelled either 1 or 0 based on user input
            "attributes": [],  # 2 member list: length of cluster and max confidence score from cluster
            "score": 1.0       # Calculated on the fly by multiplying attributes by learned weight
          }
        }

        Score is used to sort entities on the fly
        """
        self.examples = entity_examples
        # Set by label(); the other methods unpack it as (weights, bias).
        self.weight = None

    @staticmethod
    def _sigmoid(z):
        """Return the logistic function of ``z`` without overflowing.

        The textbook form ``exp(z) / (1 + exp(z))`` evaluates to ``inf / inf``
        (nan) for large positive ``z``.  Here ``exp`` is only ever applied to
        a non-positive number, so the intermediate value stays in [0, 1].
        """
        z = numpy.asarray(z)
        decay = numpy.exp(-numpy.abs(z))
        result = numpy.where(z >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
        # Indexing with () turns a 0-d array back into a numpy scalar and
        # leaves arrays of higher rank as they are.
        return result[()]

    def label(self, entity_id, label):
        """Record ``label`` for ``entity_id`` and refit on all labelled examples.

        Returns whatever ``rlr.lr`` returns, which is also stored on
        ``self.weight``.  Raises ``KeyError`` if ``entity_id`` is unknown.
        """
        self.examples[entity_id]['label'] = label
        labelled = [d for d in self.examples.values()
                    if d['label'] is not None]
        labels = numpy.array([d['label'] for d in labelled],
                             dtype=numpy.int32)
        examples = numpy.array([d['attributes'] for d in labelled],
                               dtype=numpy.float32)
        # NOTE(review): 0.1 is presumably the regularisation strength
        # expected by rlr.lr -- confirm against the rlr documentation.
        self.weight = rlr.lr(labels, examples, 0.1)
        return self.weight

    def _score(self):
        """Refresh the ``'score'`` of every example from the current weights.

        Does nothing when no model has been fitted yet or when there are no
        examples, so the scores already stored on the examples are kept.
        """
        if self.weight is None or not self.examples:
            return
        weights, bias = self.weight
        records = list(self.examples.values())
        attributes = [d['attributes'] for d in records]
        scores = self._sigmoid(numpy.dot(attributes, weights) + bias)
        # Convert to plain Python floats once rather than once per entity.
        for record, score in zip(records, scores.tolist()):
            record['score'] = score

    def _sort(self):
        """Re-score, then rebind ``self.examples`` to an OrderedDict sorted by ascending score."""
        self._score()
        self.examples = OrderedDict(
            sorted(self.examples.items(), key=lambda item: item[1]['score']))

    def predict(self, example):
        """Return the logistic score of ``example`` under the current weights.

        ``example`` is anything ``numpy.dot`` accepts against the weight
        vector (a single attribute list or a list of them).  ``label`` must
        have been called first; with ``self.weight`` still ``None`` the
        unpacking below raises ``TypeError``.
        """
        weights, bias = self.weight
        return self._sigmoid(numpy.dot(example, weights) + bias)

    def get_next(self):
        """Return the id of the lowest-scoring unlabelled entity.

        Returns ``None`` when every entity already carries a label (or there
        are no entities at all).
        """
        self._sort()
        for entity_id, example in self.examples.items():
            if example['label'] is None:
                return entity_id
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random

from api.database import init_engine
from api.app_config import DB_CONN

# The engine must be initialised before worker_session is imported below.
init_engine(DB_CONN)

if __name__ == "__main__":
    # Smoke test: label up to 100 entities at random, printing the entity the
    # machine proposes after each label, then print predictions for a slice
    # of the clusters.
    from api.utils.review_machine import ReviewMachine
    from api.utils.helpers import getCluster
    from api.database import worker_session

    engine = worker_session.bind
    sid = '74d38735-9e80-4cbe-9891-a6b32a8d8dfa'
    # A table name cannot be sent as a bound parameter, hence the format();
    # sid is a hard-coded constant here, not user input.
    sel = '''
        SELECT
          entity_id,
          MAX(confidence)::DOUBLE PRECISION,
          COUNT(*)
        FROM "entity_{0}"
        GROUP BY entity_id
    '''.format(sid)
    clusters = list(engine.execute(sel))
    # One example per entity: attributes are (max confidence, cluster size).
    examples = {c[0]: {'attributes': c[1:], 'label': None, 'score': 1.0}
                for c in clusters}
    machine = ReviewMachine(examples)

    if examples:
        entity_id = next(iter(examples))
        for _ in range(100):
            label = random.randrange(0, 2)
            machine.label(entity_id, label)
            entity_id = machine.get_next()
            print(entity_id)
            if entity_id is None:
                # Every entity has been labelled; nothing left to review.
                break

    # list() keeps the slice working where dict.items() returns a view.
    for entity_id, d in list(examples.items())[200:300]:
        example = [d['attributes']]
        print(machine.predict(example))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment