Skip to content

Instantly share code, notes, and snippets.

@kylegallatin
Created January 16, 2020 20:58
Show Gist options
  • Save kylegallatin/0860a1b51101c7bd9c2fcc0d7b6f0906 to your computer and use it in GitHub Desktop.
Save kylegallatin/0860a1b51101c7bd9c2fcc0d7b6f0906 to your computer and use it in GitHub Desktop.
import pickle
import numpy as np
import pandas as pd
import string
import sklearn.feature_extraction
from sklearn.metrics.pairwise import linear_kernel
class PubmedTfidfTrainServe:
    """Train a TF-IDF model over a CSV text column and serve similarity search.

    Typical use: ``load_data(path)`` -> ``train()`` (fits the vectorizer and
    pickles it) -> ``search(query)``. A previously pickled vectorizer can be
    reused with ``load_model()`` instead of ``train()``; ``load_data()`` must
    still be called so there is a corpus to search.
    """

    # Translation table that deletes every ASCII punctuation character;
    # built once instead of on every preprocess_data() call.
    _PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self):
        # Fitted TfidfVectorizer; None until train() or load_model() runs.
        self.vectorizer = None
        # File the vectorizer is pickled to / loaded from by default.
        self.vectorizer_name = "vectorizer.pickle"
        # Passed through to TfidfVectorizer(lowercase=...).
        self.lowercase = True
        # Preprocessed documents (list of str); None until load_data() runs.
        self.text_data = None
        # TF-IDF document matrix for text_data; set by train(), or rebuilt
        # lazily by search() when it is None (e.g. after load_model()).
        self.tfidf = None

    @staticmethod
    def preprocess_data(text):
        """Return *text* lowercased, with ASCII punctuation removed and stripped.

        Non-string input is converted with ``str()`` first.
        """
        text = str(text).lower()
        text = text.translate(PubmedTfidfTrainServe._PUNCTUATION_TABLE)
        return text.strip()

    def load_data(self, path, column="Abstract"):
        """Read the CSV at *path* and store the preprocessed *column* as the corpus.

        Args:
            path: CSV location (anything ``pandas.read_csv`` accepts).
            column: Name of the text column to index. Defaults to
                ``"Abstract"``, the column that was previously hard-coded.

        Raises:
            KeyError: If *column* is not present in the CSV.
        """
        df = pd.read_csv(path)
        # Empty cells are read as NaN; without fillna they would be indexed
        # as the literal word "nan".
        texts = df[column].fillna("")
        self.text_data = [self.preprocess_data(x) for x in texts]
        # Any previously built matrix describes a different corpus.
        self.tfidf = None

    def train(self):
        """Fit a TfidfVectorizer on the loaded corpus, then pickle it via save_model().

        Raises:
            RuntimeError: If load_data() has not been called.
        """
        if self.text_data is None:
            raise RuntimeError("No corpus loaded; call load_data() first.")
        self.vectorizer = sklearn.feature_extraction.text.TfidfVectorizer(
            lowercase=self.lowercase
        )
        self.tfidf = self.vectorizer.fit_transform(self.text_data)
        self.save_model()

    def save_model(self, model_path=None):
        """Pickle the vectorizer to *model_path* (default: ``self.vectorizer_name``)."""
        if model_path is None:
            model_path = self.vectorizer_name
        with open(model_path, "wb") as f:
            pickle.dump(self.vectorizer, f)

    def load_model(self, model_path=None):
        """Load a pickled vectorizer from *model_path* (default: ``self.vectorizer_name``).

        Security: unpickling can execute arbitrary code, so only load files
        from a trusted source.
        """
        if model_path is None:
            model_path = self.vectorizer_name
        with open(model_path, "rb") as f:
            self.vectorizer = pickle.load(f)
        # A matrix built with a different vectorizer would not match the
        # new vocabulary; search() rebuilds it on demand.
        self.tfidf = None

    def search(self, text, n_results=10):
        """Return up to *n_results* corpus documents most similar to *text*.

        Loads the default pickled vectorizer if none is set yet. Results are
        ordered from most to least similar.

        Raises:
            RuntimeError: If load_data() has not been called.
        """
        if self.vectorizer is None:
            self.load_model()
        if self.text_data is None:
            raise RuntimeError("No corpus loaded; call load_data() first.")
        if self.tfidf is None:
            self.tfidf = self.vectorizer.transform(self.text_data)
        # Preprocess the query the same way as the corpus so tokens line up
        # with the fitted vocabulary (e.g. "non-small" -> "nonsmall").
        vector = self.vectorizer.transform([self.preprocess_data(text)])
        # With TfidfVectorizer's default l2 row normalisation, the linear
        # kernel equals cosine similarity.
        cosine_similarities = linear_kernel(vector, self.tfidf).flatten()
        # argsort is ascending: reverse it, then keep exactly n_results
        # (the old [:-n:-1] slice returned only n - 1).
        ranked = cosine_similarities.argsort()[::-1][:max(n_results, 0)]
        return [self.text_data[i] for i in ranked]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment