Last active
April 18, 2020 14:09
-
-
Save Mageswaran1989/1197796d81674444391b25074f79b989 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import string

import gin
import nltk
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from snorkel.labeling import labeling_function
from snorkel.labeling import LFApplier
from snorkel.labeling import LFAnalysis
from snorkel.labeling import LabelModel
# ANSI escape sequences used to colour console output.
CEND = '\33[0m'      # reset all attributes
CBLUE = '\33[34m'
CYELLOW2 = '\33[93m'
CRED = '\33[31m'
CGREEN2 = '\33[92m'


def _print_colored(color, *args):
    """
    Prints ``args`` wrapped between ``color`` and the ANSI reset code.

    The values are converted with ``str`` and joined with single spaces, the
    same way the built-in ``print`` renders several positional arguments.
    (``str(*args)`` cannot be used: ``str(a, b)`` interprets ``b`` as an
    encoding and raises TypeError.)

    :param color: ANSI escape sequence emitted before the text
    :param args: values to print
    :return: None (writes to stdout)
    """
    print(color + " ".join(str(arg) for arg in args) + CEND)


def print_info(*args):
    """
    Prints the given values in green.

    :param args: user string information
    :return: None (writes to stdout)
    """
    _print_colored(CGREEN2, *args)


def print_error(*args):
    """
    Prints the given values in red.

    :param args: user string information
    :return: None (writes to stdout)
    """
    _print_colored(CRED, *args)


def print_warn(*args):
    """
    Prints the given values in yellow.

    :param args: user string information
    :return: None (writes to stdout)
    """
    _print_colored(CYELLOW2, *args)


def print_debug(*args):
    """
    Prints the given values in blue.

    :param args: user string information
    :return: None (writes to stdout)
    """
    _print_colored(CBLUE, *args)
class AIKeyWords(object):
    """
    Keyword groups for the tweet labelling functions.

    Each attribute is one string of alternatives separated by ``|``; split it
    with ``.split("|")`` to get the individual keywords / phrases.
    """

    AI = "#AI|Artificial Intelligence|robotics"
    ML = "machinelearningengineer|Machine Learning|scikit|#ML|mathematics"
    DL = "DeepLearning|Deep Learning|#DL|Tensorflow|Pytorch|Neural Network|NeuralNetwork"
    CV = "computervision|computer vision|machine vision|machinevision|convolutional network|convnet|image processing"
    NLP = "NLP|naturallanguageprocessing|natural language processing|text processing|text analytics|nltk|spacy"
    DATA = "iot|datasets|dataengineer|analytics|bigdata|big data|data science|data analytics|data insights|data mining|distributed computing|parallel processing|apache spark|hadoop|apache hive|airflow|mlflow|apache kafka|hdfs|apache|kafka"
    TWEET_HASH_TAGS = "dataanalysis|AugmentedIntelligence|datascience|machinelearning|rnd|businessintelligence|DigitalTransformation|datamanagement|ArtificialIntelligence"
    # Stand-alone words (mostly fragments of the phrases above, e.g. "big",
    # "data", "deep", "artificial") that on their own suggest a tweet only
    # superficially resembles an AI/data tweet.
    # "artificial" was previously misspelled "artifical", and "data" was listed twice.
    FALSE_POSITIVE = "gpu|nvidia|maths|mathematics|intelligence|conspiracy|astrology|vedic|tamil|text|computer|ebook|pdf|learning|big|insights|processing|network|machine|artificial|data|science|parallel|computing|deep|vision|natural|language"
    # Unrelated topics; combined into ALL below.
    RANDOM_TOPICS = "nature|climate|space|earth|animals|plants|astrology|horoscope|occult|hidden science|conspiracy|hinduism|hindu|vedic"
    # Union of every AI / data-science group.
    POSITIVE = "|".join([AI, ML, DL, CV, NLP, DATA, TWEET_HASH_TAGS])
    # Every keyword defined in this class.
    ALL = "|".join([POSITIVE, FALSE_POSITIVE, RANDOM_TOPICS])
class SSPTweetLabeller(BaseEstimator, TransformerMixin):
    """
    Snorkel transformer that uses labelling functions (LFs) to train a Label
    Model, which then annotates AI text (1) and non-AI text (0).

    :param input_col: Name of the input text column if a DataFrame is used
    :param output_col: Name of the output label column if a DataFrame is used
    """

    # Voting values emitted by the LFs.
    # The LF has no opinion about the tweet.
    ABSTAIN = -1
    # Tweets that talk about science, AI, data.
    POSITIVE = 1
    # Tweets that do not.
    NEGATIVE = 0

    def __init__(self,
                 input_col="text",
                 output_col="slabel"):
        # scikit-learn's get_params()/set_params()/clone()/repr() read
        # attributes named exactly like the __init__ parameters; recent
        # versions raise AttributeError if they are missing.
        self.input_col = input_col
        self.output_col = output_col
        # LFs need to be static or plain functions.
        self._labelling_functions = [self.is_ai_tweet,
                                     self.is_not_ai_tweet,
                                     self.not_data_science,
                                     self.not_neural_network,
                                     self.not_big_data,
                                     self.not_nlp,
                                     self.not_ai,
                                     self.not_cv]
        self._list_applier = LFApplier(lfs=self._labelling_functions)
        self._label_model = LabelModel(cardinality=2, verbose=True)

    def _to_texts(self, X):
        """
        Normalises the supported input types to a sequence of texts.

        :param X: DataFrame (text taken from ``input_col``), a single string,
            or a list / tuple / Series of strings
        :return: sequence of texts suitable for ``LFApplier.apply``
        :raises RuntimeError: for any other input type
        """
        if isinstance(X, pd.DataFrame):
            return X[self.input_col]
        if isinstance(X, str):
            return [X]
        if isinstance(X, (list, tuple, pd.Series)):
            return X
        raise RuntimeError("Unknown type...")

    def _apply_lfs(self, texts):
        """
        Runs every LF over ``texts`` and prints the LF summary table.

        :param texts: sequence of texts
        :return: LF vote matrix [num of samples, num of LF functions]
        """
        X_labels = self._list_applier.apply(texts)
        print_info(LFAnalysis(L=X_labels, lfs=self._labelling_functions).lf_summary())
        return X_labels

    def fit(self, X, y=None):
        """
        Applies the LFs to ``X`` and trains the label model on their votes.

        :param X: (Dataframe) / (List) / (str) Input text
        :param y: Ignored; present for scikit-learn API compatibility
        :return: self
        :raises RuntimeError: for unsupported input types
        """
        X_labels = self._apply_lfs(self._to_texts(X))
        print_info("Training LabelModel")
        self._label_model.fit(L_train=X_labels, n_epochs=500, log_freq=100, seed=42)
        return self

    def normalize_prob(self, res, threshold=0.5):
        """
        Converts positive-class probabilities into hard labels.

        :param res: iterable of probabilities
        :param threshold: values strictly above it become POSITIVE
        :return: list of POSITIVE (1) / NEGATIVE (0) labels
        """
        return [self.POSITIVE if r > threshold else self.NEGATIVE for r in res]

    def transform(self, X, y=None):
        """
        Labels ``X`` with the trained label model (1 = AI text, 0 = other).

        :param X: DataFrame, list / tuple / Series of texts, or a single string
        :param y: Ignored
        :return: for a DataFrame, the same DataFrame with ``output_col`` added
            (modified in place); for a sequence, a list of labels; for a
            string, a single label
        :raises ValueError: if a DataFrame is given but ``input_col`` is empty
        :raises RuntimeError: for unsupported input types
        """
        if isinstance(X, pd.DataFrame):
            if not self.input_col:
                raise ValueError("input_col must be set to transform a DataFrame")
            res = self.predict(X[self.input_col])[:, 1]
            X[self.output_col] = self.normalize_prob(res)
            return X
        if isinstance(X, str):
            return self.normalize_prob(self.predict([X])[:, 1])[0]
        if isinstance(X, (list, tuple, pd.Series)):
            return self.normalize_prob(self.predict(X)[:, 1])
        raise RuntimeError("Unknown type...")

    def predict(self, X):
        """
        Returns the label model's class probabilities for ``X``.

        A single string or a DataFrame is first normalised to a sequence of
        texts; before that fix a bare string was iterated character by character.

        :param X: sequence of texts, a single string, or a DataFrame
        :return: probability matrix with one column per class (cardinality 2);
            column 1 is used as the positive-class probability by ``transform``
        """
        if isinstance(X, (str, pd.DataFrame)):
            X = self._to_texts(X)
        return self._label_model.predict_proba(L=self._list_applier.apply(X))

    def evaluate(self, X, y):
        """
        Scores the label model against gold labels and prints the accuracy.

        :param X: DataFrame, list / tuple / Series of texts, or a single string
        :param y: gold labels aligned with ``X``
        :return: label model accuracy (a fraction; printed as a percentage)
        :raises RuntimeError: for unsupported input types
        """
        X_labels = self._apply_lfs(self._to_texts(X))
        label_model_acc = self._label_model.score(L=X_labels, Y=y, tie_break_policy="random")[
            "accuracy"
        ]
        print(f"{'Label Model Accuracy:':<25} {label_model_acc * 100:.1f}%")
        return label_model_acc

    @staticmethod
    def _normalize(text):
        """Removes '#' and '@' and lower-cases, so hashtags and mentions match plain words."""
        return text.replace("#", "").replace("@", "").lower()

    @staticmethod
    def _contains_any(padded_text, keywords):
        """
        Returns True if any keyword, normalised like the text, occurs in
        ``padded_text`` as a whole space-delimited phrase.

        Normalising the keyword as well is what lets entries such as "#AI"
        match; previously only the text had its '#' removed.
        """
        for keyword in keywords:
            needle = SSPTweetLabeller._normalize(keyword)
            if needle and f' {needle} ' in padded_text:
                return True
        return False

    @staticmethod
    def positive_search(data, key_words):
        """
        Votes POSITIVE when any keyword occurs in ``data`` (case-insensitive,
        whole space-delimited phrase, '#'/'@' ignored).

        :param data: tweet text; non-string values (e.g. NaN) abstain
        :param key_words: iterable of keywords / phrases
        :return: POSITIVE or ABSTAIN
        """
        if not isinstance(data, str):
            return SSPTweetLabeller.ABSTAIN
        padded = f' {SSPTweetLabeller._normalize(data)} '
        if SSPTweetLabeller._contains_any(padded, key_words):
            return SSPTweetLabeller.POSITIVE
        return SSPTweetLabeller.ABSTAIN

    @staticmethod
    def negative_search(data, positive_keywords, false_positive_keywords):
        """
        Votes NEGATIVE when ``data`` contains a false-positive keyword and no
        positive keyword (same matching rules as ``positive_search``).

        :param data: tweet text; non-string values (e.g. NaN) abstain
        :param positive_keywords: keywords that veto the negative vote
        :param false_positive_keywords: keywords that trigger the negative vote
        :return: NEGATIVE or ABSTAIN
        """
        if not isinstance(data, str):
            return SSPTweetLabeller.ABSTAIN
        padded = f' {SSPTweetLabeller._normalize(data)} '
        if (SSPTweetLabeller._contains_any(padded, false_positive_keywords)
                and not SSPTweetLabeller._contains_any(padded, positive_keywords)):
            return SSPTweetLabeller.NEGATIVE
        return SSPTweetLabeller.ABSTAIN

    @staticmethod
    def bigram_check(x, word1, word2):
        """
        Votes NEGATIVE when ``word1`` occurs as a word that is not immediately
        followed by ``word2`` (e.g. "data" used outside "data science").

        Words are compared whole, case-insensitively, with surrounding
        punctuation (including '#'/'@') stripped. The previous substring test
        on joined bigrams wrongly fired on "I love data science" via the
        bigram "love data".

        :param x: tweet text; non-string values (e.g. NaN) abstain
        :param word1: leading word of the phrase
        :param word2: word that must follow ``word1``
        :return: NEGATIVE or ABSTAIN
        """
        if not isinstance(x, str):
            return SSPTweetLabeller.ABSTAIN
        first, second = word1.lower(), word2.lower()
        tokens = [token.strip(string.punctuation) for token in x.lower().split()]
        # Pair each token with its successor; the last token gets None, so a
        # trailing word1 also counts as "not followed by word2".
        for current, following in zip(tokens, tokens[1:] + [None]):
            if current == first and following != second:
                return SSPTweetLabeller.NEGATIVE
        return SSPTweetLabeller.ABSTAIN

    # Positive LF: any AI / data-science keyword present.
    @staticmethod
    @labeling_function()
    def is_ai_tweet(x):
        return SSPTweetLabeller.positive_search(x, AIKeyWords.POSITIVE.split("|"))

    # Negative LF: only "false positive" words present, no real AI keyword.
    @staticmethod
    @labeling_function()
    def is_not_ai_tweet(x):
        return SSPTweetLabeller.negative_search(data=x,
                                                positive_keywords=AIKeyWords.POSITIVE.split("|"),
                                                false_positive_keywords=AIKeyWords.FALSE_POSITIVE.split("|"))

    # Negative LF: "data" not followed by "science".
    @staticmethod
    @labeling_function()
    def not_data_science(x):
        return SSPTweetLabeller.bigram_check(x, "data", "science")

    # Negative LF: "neural" not followed by "network".
    @staticmethod
    @labeling_function()
    def not_neural_network(x):
        return SSPTweetLabeller.bigram_check(x, "neural", "network")

    # Negative LF: "big" not followed by "data".
    @staticmethod
    @labeling_function()
    def not_big_data(x):
        return SSPTweetLabeller.bigram_check(x, "big", "data")

    # Negative LF: "natural" not followed by "language".
    @staticmethod
    @labeling_function()
    def not_nlp(x):
        return SSPTweetLabeller.bigram_check(x, "natural", "language")

    # Negative LF: "artificial" not followed by "intelligence".
    @staticmethod
    @labeling_function()
    def not_ai(x):
        return SSPTweetLabeller.bigram_check(x, "artificial", "intelligence")

    # Negative LF: "computer" not followed by "vision".
    @staticmethod
    @labeling_function()
    def not_cv(x):
        return SSPTweetLabeller.bigram_check(x, "computer", "vision")
def main():
    """
    Demo entry point: trains the labeller on ``train.parquet`` and reports the
    label model's accuracy against that file's ``label`` column.
    """
    labeller = SSPTweetLabeller()
    tweets = pd.read_parquet("train.parquet", engine="fastparquet")
    print_info(tweets)
    labeller.fit(tweets)
    labeller.evaluate(tweets, tweets["label"])


if __name__ == "__main__":
    main()

# Console output captured from an earlier run of this script (reference only).
"""
(vh) mageswarand@IMCHLT276:~/ssp/data/dump/1197796d81674444391b25074f79b989$ python snorkel_ai_text_annotater.py
id_str created_at source text ... slabel text_id label id
0 1248501381943554050 Fri Apr 10 06:41:40 +0000 2020 WordPress.com The Ancient History of Artificial Intelligence... ... 1 1 1 0
1 1248516449553039369 Fri Apr 10 07:41:32 +0000 2020 Twitter Web App 2,000 #Drones pick up a 40-ton truck by @Scani... ... 1 2 1 1
2 1248506778788098054 Fri Apr 10 07:03:06 +0000 2020 Paper.li Small Business Security is out! https://t.co/q... ... 1 3 1 2
3 1248500801313492992 Fri Apr 10 06:39:21 +0000 2020 Twitter for Android Drelly and I have been learning to draw over t... ... 0 4 0 3
4 1248503150241140736 Fri Apr 10 06:48:41 +0000 2020 Twitter for iPhone Why is that all such sob stories are abt musli... ... 0 5 0 4
... ... ... ... ... ... ... ... ... ...
17578 1248581535726534656 Fri Apr 10 12:00:10 +0000 2020 Facelift-Cloud “Culture and habits can slow the inertia of ch... ... 0 17579 0 17578
17579 1248503137964453894 Fri Apr 10 06:48:38 +0000 2020 Twitter Web App RT @MrWednesday11: @KevinCate Here some data @... ... 0 17580 0 17579
17580 1248503293287911425 Fri Apr 10 06:49:15 +0000 2020 Twitter Web App Bluetooth Mouse, Inphic Multi-Device Silent Re... ... 0 17581 0 17580
17581 1248502759592255488 Fri Apr 10 06:47:08 +0000 2020 Twitter for iPhone RT @ellajamis: I crave a love so deep it’ll ma... ... 0 17582 0 17581
17582 1248579982567276544 Fri Apr 10 11:53:59 +0000 2020 Twitter for Android @bbcmicrobot 10 REM My first program20 MODE 23... ... 0 17583 0 17582
[17583 rows x 11 columns]
17583it [00:03, 5379.99it/s]
j Polarity Coverage Overlaps Conflicts
is_ai_tweet 0 [1] 0.314338 0.054371 0.054371
is_not_ai_tweet 1 [0] 0.346300 0.161463 0.000000
not_data_science 2 [0] 0.109594 0.096002 0.043110
not_neural_network 3 [0] 0.002218 0.001706 0.001479
not_big_data 4 [0] 0.110675 0.103054 0.005517
not_nlp 5 [0] 0.013308 0.011830 0.000341
not_ai 6 [0] 0.008929 0.007621 0.006768
not_cv 7 [0] 0.010863 0.009441 0.002047
Training LabelModel
17583it [00:03, 5291.93it/s]
Label Model Accuracy: 84.6%
"""
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Accompanying Medium post: https://medium.com/@mageswaran1989/big-data-playground-for-engineers-snorkel-scikit-transformer-6da6d0bcf109