Last active
November 21, 2023 12:47
-
-
Save amoghbanta/b3286157726874fcf31a1379fc389680 to your computer and use it in GitHub Desktop.
Backend for UNICEF's Kindly to run on Google Colab
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# @title Install dependencies | |
!pip install fastapi uvicorn python-multipart transformers pyngrok | |
from fastapi import FastAPI | |
import uvicorn | |
from pyngrok import ngrok | |
import threading | |
from fastapi.middleware.cors import CORSMiddleware | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer,AutoModelForSeq2SeqLM | |
import logging | |
import torch | |
from pydantic import BaseModel | |
import nest_asyncio | |
from pyngrok import ngrok | |
import uvicorn | |
from IPython.display import HTML | |
# Run the models on the GPU when the Colab runtime provides one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

logging.basicConfig(level=logging.INFO)

# CORS is opened up completely, presumably so a browser front-end served from another
# origin can call the API.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very broad,
# and the endpoint below uses no cookies or auth. Confirm the front-end does not send
# credentialed requests before tightening this.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Hugging Face Hub checkpoints used by the service.
_OFFENSIVE_CHECKPOINT = "cardiffnlp/twitter-roberta-base-offensive"
_GRAMMARLY_CHECKPOINT = "grammarly/coedit-large"

# Sequence classifier that scores how offensive a piece of text is.
tokenizer = AutoTokenizer.from_pretrained(_OFFENSIVE_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_OFFENSIVE_CHECKPOINT).to(device)

# Grammarly CoEdIT seq2seq model, used for grammar correction / rephrasing.
grammarly_tokenizer = AutoTokenizer.from_pretrained(_GRAMMARLY_CHECKPOINT)
grammarly_model = AutoModelForSeq2SeqLM.from_pretrained(_GRAMMARLY_CHECKPOINT).to(device)
class OffensiveInput(BaseModel):
    # Request body for POST /detect_offensive/: a JSON object with one required string field.
    text: str  # the text to score for offensiveness and to rewrite
# RoBERTa's position embeddings cover 512 tokens. Longer inputs must be truncated,
# otherwise the forward pass fails with an index error. On CUDA that surfaces as a
# device-side assert, which can leave the CUDA context unusable for later requests.
_CLASSIFIER_MAX_INPUT_TOKENS = 512
_GRAMMARLY_MAX_INPUT_TOKENS = 512
# Without an explicit limit, `generate` falls back to the library-default
# `max_length` of 20 tokens (unless the checkpoint's generation config overrides it),
# which silently cuts the corrected text short.
_GRAMMARLY_MAX_OUTPUT_TOKENS = 256


def _offensive_score(text: str) -> float:
    """Return the classifier's softmax probability for label index 1 (the offensive class)."""
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=_CLASSIFIER_MAX_INPUT_TOKENS,
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # Single input, so row 0; column 1 is the offensive class.
    return probabilities[0, 1].item()


def _correct_text(text: str) -> str:
    """Run *text* through the Grammarly CoEdIT model and return the decoded output."""
    inputs = grammarly_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=_GRAMMARLY_MAX_INPUT_TOKENS,
    ).to(device)
    # NOTE(review): the CoEdIT model card prefixes inputs with an instruction
    # (e.g. "Fix grammatical errors in this sentence: ..."). The raw text is passed
    # here, so the kind of edit the model makes is not pinned down. Confirm this is
    # intended.
    with torch.no_grad():
        output_ids = grammarly_model.generate(
            **inputs, max_length=_GRAMMARLY_MAX_OUTPUT_TOKENS
        )
    return grammarly_tokenizer.decode(output_ids[0], skip_special_tokens=True)


@app.post("/detect_offensive/")
async def detect_offensive(input: OffensiveInput):
    """Score ``input.text`` for offensiveness and return a model-rewritten version.

    Returns:
        A dict with ``offensive_score`` (float probability of the offensive class)
        and ``corrected_text`` (the Grammarly model's output for the same text).
    """
    # NOTE(review): both model calls are synchronous and run on the event loop, so
    # concurrent requests are served one at a time. Moving them to a thread pool
    # would need care, because HF fast tokenizers are not safe to share across threads.
    return {
        "offensive_score": _offensive_score(input.text),
        "corrected_text": _correct_text(input.text),
    }
# Start server
# Expose the local FastAPI server through an ngrok tunnel and run it in this notebook.
# The auth token and reserved domain can be overridden through environment variables;
# the defaults are the values this notebook originally hard-coded. Set NGROK_DOMAIN
# to an empty string to get a random ngrok URL instead of the reserved domain.
import html
import os

from IPython.display import display

PORT = 8000
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN", "ngrok_auth_token_goes_here")
NGROK_DOMAIN = os.environ.get("NGROK_DOMAIN", "lately-absolute-thrush.ngrok-free.app")

ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# A tunnel left over from a previous run of this cell still holds the reserved domain,
# and ngrok refuses to open a second endpoint on it, so close stale tunnels first.
tunnels = ngrok.get_tunnels()
print("tunnels are:", tunnels)
for stale_tunnel in tunnels:
    ngrok.disconnect(stale_tunnel.public_url)

connect_options = {"domain": NGROK_DOMAIN} if NGROK_DOMAIN else {}
ngrok_tunnel = ngrok.connect(PORT, **connect_options)

# public_url has no trailing slash, so each path needs its own leading "/".
base_url = ngrok_tunnel.public_url.rstrip("/")
app_url = f"{base_url}/detect_offensive/"
docs_url = f"{base_url}/docs"
print('Your app is hosted at', ngrok_tunnel.public_url)
print('POST requests go to', app_url)

# /detect_offensive/ only accepts POST, so opening it in a browser returns 405.
# Link to FastAPI's interactive docs page instead, where the endpoint can be tried.
display(HTML(
    f'<a href="{html.escape(docs_url, quote=True)}" target="_blank">Open app in a new tab</a>'
))

nest_asyncio.apply()
uvicorn.run(app, port=PORT)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment