Last active
March 7, 2022 19:57
-
-
Save tezansahu/e9a414aa476fe7c0fa96bf7670bd6e36 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
import random

import mlfoundry as mlf
from simpletransformers.classification import ClassificationModel
# Function to load a Simple Transformers model & predict sentiment
def predict_model(model_params, input_headline):
    """Load a saved Simple Transformers classifier and label one headline.

    Args:
        model_params: Mapping with the keys ``"model_type"`` (passed as the
            first argument to ``ClassificationModel``) and ``"model_name"``
            (sub-directory of ``outputs`` that holds the saved model).
        input_headline: Headline text; it is lower-cased before inference.

    Returns:
        One of ``"negative"``, ``"neutral"`` or ``"positive"``. If loading
        the model or running inference fails for any reason, a randomly
        chosen label is returned so the UI keeps working, and the failure
        is logged with its traceback.
    """
    class_name_map = {
        0: "negative",
        1: "neutral",
        2: "positive",
    }
    try:
        # NOTE(review): the model is rebuilt from disk on every call;
        # consider caching it if latency becomes a concern.
        model_loaded = ClassificationModel(
            model_params["model_type"],
            os.path.join("outputs", model_params["model_name"]),
        )
        # The first element of the predict() result holds the predictions;
        # only one headline is submitted, so take its first entry.
        model_output = model_loaded.predict([input_headline.lower()])[0][0]
        return class_name_map[model_output]
    except Exception:
        # Deliberate best-effort: the UI must always receive a label. Log the
        # failure so a random answer is never mistaken for a real prediction.
        logging.getLogger(__name__).exception(
            "Sentiment prediction failed; returning a random label instead"
        )
        return random.choice(["negative", "neutral", "positive"])
# Function to predict sentiment for headline using fine-tuned RoBERTa
def predict_roberta(input_headline: str) -> str:
    """Return the sentiment label for a headline using the RoBERTa model.

    Delegates to ``predict_model`` with both the model type and the model
    name set to ``"roberta"``.
    """
    return predict_model(
        {"model_type": "roberta", "model_name": "roberta"},
        input_headline,
    )
# Obtain an mlfoundry client and open a run under the web-app project
mlf_api = mlf.get_client()
mlf_run = mlf_api.create_run(
    project_name="financial_sentiment_analysis_webapp",
)

# Register predict_roberta as a text-in / text-out demo via the run's webapp API;
# the call's result is unpacked into two module-level names.
raw_in, raw_out = mlf_run.webapp(fn=predict_roberta, inputs="text", outputs="text")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment