# Source code for ssp.snorkel.labelling_function
#!/usr/bin/env python
__author__ = "Mageswaran Dhandapani"
__copyright__ = "Copyright 2020, The Spark Structured Playground Project"
__credits__ = []
__license__ = "Apache License"
__version__ = "2.0"
__maintainer__ = "Mageswaran Dhandapani"
__email__ = "mageswaran1989@gmail.com"
__status__ = "Education Purpose"
import gin
import nltk
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from snorkel.labeling import LabelModel
from snorkel.labeling import LFAnalysis
from snorkel.labeling import LFApplier
from snorkel.labeling import labeling_function

from ssp.logger.pretty_print import print_error
from ssp.logger.pretty_print import print_info
from ssp.posgress.dataset_base import PostgresqlDatasetBase
from ssp.utils.ai_key_words import AIKeyWords
class SSPTweetLabeller(BaseEstimator, TransformerMixin):
    """
    Snorkel transformer that applies a set of labelling functions (LFs) to tweet texts and
    trains a Snorkel LabelModel on their votes, so that it can annotate AI and non-AI tweets.

    fit / transform / evaluate accept a single string, a list / tuple / ndarray / Series of
    strings, or a DataFrame holding the text in ``input_col``.

    :param input_col: Name of the input text column if a DataFrame is used
    :param output_col: Name of the output label column if a DataFrame is used
    """
    # Vote values returned by the labelling functions.
    # The LF has no opinion about the tweet
    ABSTAIN = -1
    # Tweet talks about science, AI, data
    POSITIVE = 1
    # Tweet is not about those topics
    NEGATIVE = 0

    def __init__(self,
                 input_col="text",
                 output_col="slabel"):
        # LFs need to be static or plain functions: Snorkel calls them with the data point only
        self._labelling_functions = [self.is_ai_tweet,
                                     self.is_not_ai_tweet,
                                     self.not_data_science,
                                     self.not_neural_network,
                                     self.not_big_data,
                                     self.not_nlp,
                                     self.not_ai,
                                     self.not_cv]
        self._input_col = input_col
        self._output_col = output_col
        self._list_applier = LFApplier(lfs=self._labelling_functions)
        # Binary problem: NEGATIVE (0) vs POSITIVE (1)
        self._label_model = LabelModel(cardinality=2, verbose=True)

    def _to_text_list(self, X):
        """
        Normalise the supported input types to a plain list of texts.

        :param X: str, list / tuple / ndarray / Series of str, or DataFrame holding ``input_col``
        :return: List of texts
        :raises RuntimeError: For any other input type
        """
        if isinstance(X, str):
            return [X]
        if isinstance(X, pd.DataFrame):
            return X[self._input_col].tolist()
        if isinstance(X, (list, tuple, pd.Series, np.ndarray)):
            return list(X)
        raise RuntimeError("Unknown type...")

    def _apply_lfs(self, X):
        """Run every LF over ``X`` and return the label matrix [num of samples, num of LFs]."""
        return self._list_applier.apply(self._to_text_list(X))

    def fit(self, X, y=None):
        """
        Apply the LFs to ``X``, print the LFAnalysis summary and train the LabelModel.

        :param X: (Dataframe) / (List) / (str) Input text
        :param y: Ignored, kept for scikit-learn API compatibility
        :return: self
        :raises RuntimeError: If ``X`` has an unsupported type
        """
        X_labels = self._apply_lfs(X)
        print_info(LFAnalysis(L=X_labels, lfs=self._labelling_functions).lf_summary())
        print_info("Training LabelModel")
        self._label_model.fit(L_train=X_labels, n_epochs=500, log_freq=100, seed=42)
        return self

    def predict(self, X):
        """
        Return the LabelModel class probabilities for ``X``.

        :param X: Any input accepted by ``fit``
        :return: Numpy array [num of samples, 2]; column 1 is the POSITIVE (AI) probability
        """
        return self._label_model.predict_proba(L=self._apply_lfs(X))

    @staticmethod
    def normalize_prob(res, threshold=0.5):
        """
        Convert POSITIVE-class probabilities into hard labels.

        :param res: Iterable of probabilities
        :param threshold: Probabilities >= threshold become POSITIVE, the rest NEGATIVE
        :return: List of SSPTweetLabeller.POSITIVE / SSPTweetLabeller.NEGATIVE
        """
        return [SSPTweetLabeller.POSITIVE if prob >= threshold else SSPTweetLabeller.NEGATIVE
                for prob in res]

    def transform(self, X, y=None):
        """
        Label ``X`` with the trained LabelModel.

        :param X: str, list-like of str, or DataFrame holding ``input_col``
        :param y: Ignored, kept for scikit-learn API compatibility
        :return: For a DataFrame, the same DataFrame with ``output_col`` added (modified in
                 place); for a str, a single label; otherwise a list of labels
        :raises RuntimeError: If ``X`` has an unsupported type
        """
        if isinstance(X, pd.DataFrame):
            X[self._output_col] = self.normalize_prob(self.predict(X[self._input_col])[:, 1])
            return X
        labels = self.normalize_prob(self.predict(X)[:, 1])
        return labels[0] if isinstance(X, str) else labels

    def evaluate(self, X, y):
        """
        Print the LF summary and the LabelModel accuracy of ``X`` against gold labels ``y``.

        :param X: Any input accepted by ``fit``
        :param y: Gold labels (list / Series / array), aligned with ``X``
        :return: LabelModel accuracy in [0, 1]
        :raises RuntimeError: If ``X`` has an unsupported type
        """
        X_labels = self._apply_lfs(X)
        # LabelModel.score needs an ndarray (it reads ``Y.ndim``), so lists/Series are converted
        label_model_acc = self._label_model.score(L=X_labels,
                                                  Y=np.asarray(y),
                                                  tie_break_policy="random")["accuracy"]
        print_info(LFAnalysis(L=X_labels, lfs=self._labelling_functions).lf_summary())
        print(f"{'Label Model Accuracy:':<25} {label_model_acc * 100:.1f}%")
        return label_model_acc

    @staticmethod
    def _normalise_text(data):
        """Drop '#'/'@', lower-case and pad with spaces so whole-phrase matches can use `in`."""
        return f" {data.replace('#', '').replace('@', '').lower()} "

    @staticmethod
    def _contains_any(padded_text, keywords):
        """True if any non-blank keyword occurs as a space-delimited phrase in ``padded_text``."""
        for keyword in keywords:
            keyword = keyword.strip().lower()
            # A blank keyword (e.g. from a trailing "|") would match almost anything
            if keyword and f" {keyword} " in padded_text:
                return True
        return False

    @staticmethod
    def positive_search(data, key_words):
        """
        Vote POSITIVE if any of ``key_words`` occurs as a whole phrase in the text.

        :param data: Tweet text; non-str values (e.g. NaN) abstain
        :param key_words: Iterable of keywords / phrases
        :return: POSITIVE or ABSTAIN
        """
        if not isinstance(data, str):
            return SSPTweetLabeller.ABSTAIN
        if SSPTweetLabeller._contains_any(SSPTweetLabeller._normalise_text(data), key_words):
            return SSPTweetLabeller.POSITIVE
        return SSPTweetLabeller.ABSTAIN

    @staticmethod
    def negative_search(data, positive_keywords, false_positive_keywords):
        """
        Vote NEGATIVE if the text hits a false-positive keyword but no positive keyword.

        :param data: Tweet text; non-str values (e.g. NaN) abstain
        :param positive_keywords: Keywords that indicate an AI tweet
        :param false_positive_keywords: Keywords that look related but indicate a non-AI tweet
        :return: NEGATIVE or ABSTAIN
        """
        if not isinstance(data, str):
            return SSPTweetLabeller.ABSTAIN
        padded_text = SSPTweetLabeller._normalise_text(data)
        if (SSPTweetLabeller._contains_any(padded_text, false_positive_keywords)
                and not SSPTweetLabeller._contains_any(padded_text, positive_keywords)):
            return SSPTweetLabeller.NEGATIVE
        return SSPTweetLabeller.ABSTAIN

    @staticmethod
    def bigram_check(x, word1, word2):
        """
        Vote NEGATIVE when ``word1`` is used but never as part of the phrase "word1 word2".

        Matching is case-insensitive on whole tokens; punctuation, '#' and '@' act as
        separators, so "Data-Science" and "#data #science" both count as the phrase.

        :param x: Tweet text; non-str values (e.g. NaN) abstain
        :param word1: First word of the phrase (e.g. "data")
        :param word2: Word that must follow it (e.g. "science")
        :return: NEGATIVE or ABSTAIN
        """
        if not isinstance(x, str):
            return SSPTweetLabeller.ABSTAIN
        tokens = "".join(ch if ch.isalnum() else " " for ch in x.lower()).split()
        word1, word2 = word1.lower(), word2.lower()
        if word1 not in tokens:
            return SSPTweetLabeller.ABSTAIN
        # word1 present: only abstain if it appears at least once immediately followed by word2
        if (word1, word2) in set(nltk.bigrams(tokens)):
            return SSPTweetLabeller.ABSTAIN
        return SSPTweetLabeller.NEGATIVE

    @staticmethod
    @labeling_function()
    def is_ai_tweet(x):
        return SSPTweetLabeller.positive_search(x, AIKeyWords.POSITIVE.split("|"))

    @staticmethod
    @labeling_function()
    def is_not_ai_tweet(x):
        return SSPTweetLabeller.negative_search(data=x,
                                                positive_keywords=AIKeyWords.POSITIVE.split("|"),
                                                false_positive_keywords=AIKeyWords.FALSE_POSITIVE.split("|"))

    @staticmethod
    @labeling_function()
    def not_data_science(x):
        return SSPTweetLabeller.bigram_check(x, "data", "science")

    @staticmethod
    @labeling_function()
    def not_neural_network(x):
        return SSPTweetLabeller.bigram_check(x, "neural", "network")

    @staticmethod
    @labeling_function()
    def not_big_data(x):
        return SSPTweetLabeller.bigram_check(x, "big", "data")

    @staticmethod
    @labeling_function()
    def not_nlp(x):
        return SSPTweetLabeller.bigram_check(x, "natural", "language")

    @staticmethod
    @labeling_function()
    def not_ai(x):
        return SSPTweetLabeller.bigram_check(x, "artificial", "intelligence")

    @staticmethod
    @labeling_function()
    def not_cv(x):
        return SSPTweetLabeller.bigram_check(x, "computer", "vision")
@gin.configurable
class SSPLabelEvaluator(PostgresqlDatasetBase):
    """
    Trains an SSPTweetLabeller on the Snorkel training split of the tweet dataset and
    evaluates it against the gold labels of the test split.

    Table and connection parameters are forwarded unchanged to PostgresqlDatasetBase.

    :param text_column: Name of the tweet text column; also used as the labeller's input column
    :param label_column: Name of the gold label column used for evaluation
    :param raw_tweet_table_name_prefix: Prefix of the raw tweet tables
    :param postgresql_host: PostgreSQL host
    :param postgresql_port: PostgreSQL port
    :param postgresql_database: PostgreSQL database name
    :param postgresql_user: PostgreSQL user
    :param postgresql_password: PostgreSQL password
    """
    def __init__(self,
                 text_column="text",
                 label_column="label",
                 raw_tweet_table_name_prefix="raw_tweet_dataset",
                 postgresql_host="localhost",
                 postgresql_port="5432",
                 postgresql_database="sparkstreamingdb",
                 postgresql_user="sparkstreaming",
                 postgresql_password="sparkstreaming"):
        PostgresqlDatasetBase.__init__(self,
                                       text_column=text_column,
                                       label_output_column=label_column,
                                       raw_tweet_table_name_prefix=raw_tweet_table_name_prefix,
                                       postgresql_host=postgresql_host,
                                       postgresql_port=postgresql_port,
                                       postgresql_database=postgresql_database,
                                       postgresql_user=postgresql_user,
                                       postgresql_password=postgresql_password)
        # Read the same text column the dataset is configured with, instead of the
        # labeller's hard-wired "text" default
        self._snorkel_labeler = SSPTweetLabeller(input_col=text_column)

    def run_labeler(self, version=0):
        """
        Fit the Snorkel labeller on the Snorkel training split and evaluate it on the test split.

        :param version: Dataset version forwarded to ``get_processed_datasets``
        """
        # Only the test and Snorkel-training splits are needed here
        _, test_df, _, snorkel_train_df, _ = self.get_processed_datasets(version=version)
        self._snorkel_labeler.fit(snorkel_train_df)
        # NOTE(review): assumes PostgresqlDatasetBase stores ``label_output_column`` as
        # ``self._label_output_column`` — confirm against the base class
        self._snorkel_labeler.evaluate(test_df, test_df[self._label_output_column])