Created
September 21, 2016 15:46
-
-
Save hongthaiphi/3c5e04b0ceb90273103115031624429e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TFIDFPredictor:
    """Rank candidate utterances against a context by TF-IDF similarity.

    The predictor owns a single ``TfidfVectorizer``; ``train`` fits it and
    ``predict`` uses it to score candidates with a dot product.
    """

    def __init__(self):
        self.vectorizer = TfidfVectorizer()

    def train(self, data):
        """Fit the vectorizer on every context and utterance text in *data*.

        Args:
            data: Object exposing ``Context`` and ``Utterance`` attributes,
                each with a ``.values`` array (presumably a pandas DataFrame
                of training pairs -- confirm against the caller).
        """
        corpus = np.append(data.Context.values, data.Utterance.values)
        self.vectorizer.fit(corpus)

    def predict(self, context, utterances):
        """Return candidate indices ordered from best to worst match.

        Args:
            context: A single context text.
            utterances: Iterable of candidate utterance texts.

        Returns:
            A 1-D integer array of positions into *utterances*, sorted by
            descending dot-product score against *context*.
        """
        # Convert context and utterances into tfidf vectors.
        context_vec = self.vectorizer.transform([context])
        candidate_vecs = self.vectorizer.transform(utterances)
        # The dot product measures the similarity of the resulting vectors.
        # Use the sparse container's own ``dot``: ``np.dot`` is not aware of
        # scipy sparse types and may misbehave on them.
        scores = candidate_vecs.dot(context_vec.T).toarray().ravel()
        # Sort by top results and return the indices in descending order.
        return np.argsort(scores, axis=0)[::-1]
# Evaluate TFIDF predictor: train on train_df, rank each test row's
# candidates, then report recall@n for several cutoffs.
pred = TFIDFPredictor()
pred.train(train_df)
# Address rows positionally on both sides. The original mixed a label-based
# ``test_df.Context[x]`` with a positional ``iloc[x, 1:]``, which only agree
# when test_df carries the default 0..n-1 index.
y = [
    pred.predict(test_df.Context.iloc[x], test_df.iloc[x, 1:].values)
    for x in range(len(test_df))
]
for n in [1, 2, 5, 10]:
    print("Recall @ ({}, 10): {:g}".format(n, evaluate_recall(y, y_test, n)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment