import requests
import json
from pyspark.sql.types import FloatType, ArrayType
from pyspark.sql.functions import udf
from ssp.logger.pretty_print import print_error, print_warn
from ssp.dl.tf.classifier import NaiveTextClassifier
from ssp.logger.pretty_print import print_info
def predict_text_class(text, url, tokenizer_path, timeout=30):
    """Score a single text with the remotely served text classifier.

    The text is tokenised locally with ``NaiveTextClassifier`` and the
    resulting ids are posted as JSON to the model server's ``:predict``
    endpoint (the payload uses the ``signature_name`` / ``instances`` layout).

    Args:
        text: Raw text to classify.
        url: Full URL of the model server's predict endpoint.
        tokenizer_path: Location passed to
            ``NaiveTextClassifier.load_tokenizer``.
        timeout: Seconds to wait for the model server before giving up.
            Pass ``None`` to wait indefinitely.

    Returns:
        float: ``predictions[0][1]`` of the server response -- presumably the
        score of class index 1 for the submitted text (TODO confirm against
        the model's output layout).

    Raises:
        requests.RequestException: The server could not be reached, timed
            out, or answered with an HTTP error status.
        KeyError: The response body carries no ``"predictions"`` entry.
    """
    classifier = NaiveTextClassifier()
    # TODO: the tokenizer is reloaded on every call; is this the right way to
    # load it? Move this to a Flask API as one extra layer.
    classifier.load_tokenizer(tokenizer_path=tokenizer_path)

    # Cast every id to a built-in int so json.dumps can serialise it
    # (presumably transform() yields numpy integers -- verify).
    token_ids = [int(token) for token in classifier.transform([text])[0]]

    payload = json.dumps({"signature_name": "serving_default", "instances": [token_ids]})
    headers = {"content-type": "application/json"}

    # Without a timeout a stalled model server would block the caller forever.
    response = requests.post(url, data=payload, headers=headers, timeout=timeout)
    # Fail with the real HTTP error instead of a misleading KeyError below.
    response.raise_for_status()

    predictions = json.loads(response.text)['predictions']
    return float(predictions[0][1])
# Spark SQL return type declared for the classifier UDF built below: one float per input value.
schema = FloatType()
def get_text_classifier_udf(is_docker, tokenizer_path):
    """Build a Spark UDF that scores each input value with ``predict_text_class``.

    Args:
        is_docker: Truthy when the example is triggered inside the Docker
            environment; the model server is then addressed through
            ``host.docker.internal:30125``. Otherwise ``localhost:8501``
            is used.
        tokenizer_path: Tokenizer location forwarded unchanged to
            ``predict_text_class`` on every invocation.

    Returns:
        A PySpark UDF whose declared return type is the module-level
        ``schema``.
    """
    # Only the endpoint differs between the two environments.
    endpoint = (
        "http://host.docker.internal:30125/v1/models/naive_text_clf:predict"
        if is_docker
        else "http://localhost:8501/v1/models/naive_text_clf:predict"
    )
    return udf(
        lambda value: predict_text_class(text=value, tokenizer_path=tokenizer_path, url=endpoint),
        schema,
    )
def predict(text):
    """Print the classifier score for ``text`` from the first endpoint that answers.

    The known model-server endpoints are tried in order (Docker host
    gateway, local port 8501, local port 30125). For the first one that
    returns a prediction, the URL and the score are printed and the function
    returns. Endpoints that fail are reported and skipped.

    Args:
        text: Raw text to classify.

    Returns:
        None. Results are only printed.
    """
    # NOTE(review): "~" is passed through verbatim; confirm that
    # NaiveTextClassifier.load_tokenizer expands the home directory.
    tokenizer_path = "~/ssp/model/raw_tweet_dataset_0/naive_text_classifier/1/"
    candidate_urls = (
        "http://host.docker.internal:30125/v1/models/naive_text_clf:predict",
        "http://localhost:8501/v1/models/naive_text_clf:predict",
        "http://127.0.0.1:30125/v1/models/naive_text_clf:predict",
    )

    print("\n")
    print_info(f"Text : {text} ")

    for url in candidate_urls:
        try:
            score = predict_text_class(text=text, url=url, tokenizer_path=tokenizer_path)
        except Exception as err:
            # Best effort: an unreachable endpoint is expected in at least one
            # environment, so report it and fall through to the next one.
            print_warn(f"Prediction via {url} failed: {err}")
            continue

        print_warn(url)
        print(score)
        # Return instead of calling exit(): a SystemExit raised inside a
        # try/bare-except is swallowed, and terminating the interpreter here
        # would stop callers from scoring more than one text.
        return

    print_error("No model server endpoint returned a prediction")
if __name__ == "__main__":
    # Manual smoke test: score two sample texts against whichever model
    # server endpoint is reachable.
    sample_texts = (
        "📰Machine learning as a tool to explore cognitive profiles of epileptic patients. Neuropsychological data science are meaningful artificial intelligence 📈🔍| Home https://t.co/cAQ2vZYxk2",
        "This is a random text to check whats the prediction...home so it gets classified as 0",
    )
    for sample in sample_texts:
        predict(sample)
# export PYTHONPATH=$(pwd)/src/:$PYTHONPATH