Skip to content

Instantly share code, notes, and snippets.

@tezansahu
Last active March 7, 2022 19:57
Show Gist options
  • Save tezansahu/e9a414aa476fe7c0fa96bf7670bd6e36 to your computer and use it in GitHub Desktop.
Save tezansahu/e9a414aa476fe7c0fa96bf7670bd6e36 to your computer and use it in GitHub Desktop.
import functools
import logging
import os
import random

import mlfoundry as mlf
from simpletransformers.classification import ClassificationModel
# Loading a Simple Transformers model and predicting headline sentiment.

# Maps the classifier's integer output class to a sentiment label.
_LABEL_BY_CLASS = {
    0: "negative",
    1: "neutral",
    2: "positive",
}

_logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=4)
def _load_model(model_type, model_name):
    """Load the saved model from ``outputs/<model_name>`` and reuse it on later calls.

    Building a ``ClassificationModel`` reads the model weights from disk, which
    costs far more than one prediction. The result is therefore cached per
    ``(model_type, model_name)`` pair. The cache is bounded because every
    cached model stays resident in memory. A load that raises is not cached,
    so the next call retries it.
    """
    # NOTE(review): ClassificationModel appears to default to use_cuda=True;
    # on a CPU-only host the load presumably raises, making every prediction
    # fall back to a random label — confirm and pass use_cuda explicitly if so.
    return ClassificationModel(model_type, os.path.join("outputs", model_name))


def predict_model(model_params, input_headline):
    """Predict the sentiment of a single headline.

    Args:
        model_params: dict with keys "model_type" and "model_name". The model
            is loaded from the ``outputs/<model_name>`` directory.
        input_headline: headline text. It is lower-cased before inference.

    Returns:
        One of "negative", "neutral" or "positive".

    If anything fails (missing keys, a model that cannot load, an unexpected
    class index, a non-string headline, ...), the exception is logged and a
    random label is returned instead. This keeps the web UI from ever seeing
    an error. The fallback looks like a real prediction to the caller, so
    check the log when results seem arbitrary.
    """
    try:
        model = _load_model(model_params["model_type"], model_params["model_name"])
        # predict() returns (predictions, raw_outputs); take the first prediction.
        predicted_class = model.predict([input_headline.lower()])[0][0]
        return _LABEL_BY_CLASS[predicted_class]
    except Exception:
        # Deliberate best-effort boundary for the UI, but never silent.
        _logger.exception("Sentiment prediction failed; returning a random label")
        return random.choice(list(_LABEL_BY_CLASS.values()))
def predict_roberta(input_headline: str) -> str:
    """Return the sentiment of *input_headline* using the fine-tuned RoBERTa model.

    This is a thin wrapper around ``predict_model``. It always selects the
    model of type "roberta" named "roberta", and it is the callable that the
    web demo below exposes.
    """
    roberta_config = dict(model_type="roberta", model_name="roberta")
    return predict_model(roberta_config, input_headline)
# Connect to MLFoundry and open a run under the demo project.
mlf_api = mlf.get_client()
mlf_run = mlf_api.create_run(
    project_name="financial_sentiment_analysis_webapp",
)

# Expose predict_roberta as a text-in / text-out demo through the run's
# webapp API; the two return values are kept as module-level names.
raw_in, raw_out = mlf_run.webapp(
    fn=predict_roberta,
    inputs="text",
    outputs="text",
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment