Skip to content

Instantly share code, notes, and snippets.

@amoghbanta
Last active November 21, 2023 12:47
Show Gist options
  • Save amoghbanta/b3286157726874fcf31a1379fc389680 to your computer and use it in GitHub Desktop.
Save amoghbanta/b3286157726874fcf31a1379fc389680 to your computer and use it in GitHub Desktop.
Backend for UNICEF's Kindly to run on Google Colab
# @title Install dependencies
!pip install fastapi uvicorn python-multipart transformers pyngrok
from fastapi import FastAPI
import uvicorn
from pyngrok import ngrok
import threading
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSequenceClassification, AutoTokenizer,AutoModelForSeq2SeqLM
import logging
import torch
from pydantic import BaseModel
import nest_asyncio
from pyngrok import ngrok
import uvicorn
from IPython.display import HTML
# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

# Emit INFO-level (and above) records through the root logger.
logging.basicConfig(level=logging.INFO)
# Create the FastAPI application and attach CORS middleware that accepts any
# origin, method and header, with credentials allowed.
# NOTE(review): the FastAPI CORS docs advise against combining wildcard values
# with allow_credentials=True -- confirm whether credentials are needed at all.
app = FastAPI()
_cors_settings = {
    'allow_origins': ['*'],
    'allow_credentials': True,
    'allow_methods': ['*'],
    'allow_headers': ['*'],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Offensive-language classifier: tokenizer and sequence-classification model
# are loaded from the same checkpoint, and the model is moved to `device`.
_OFFENSIVE_CHECKPOINT = "cardiffnlp/twitter-roberta-base-offensive"
tokenizer = AutoTokenizer.from_pretrained(_OFFENSIVE_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_OFFENSIVE_CHECKPOINT)
model = model.to(device)

# Grammarly CoEdIT seq2seq model used for grammar correction and rephrasing,
# also placed on `device`.
_COEDIT_CHECKPOINT = "grammarly/coedit-large"
grammarly_tokenizer = AutoTokenizer.from_pretrained(_COEDIT_CHECKPOINT)
grammarly_model = AutoModelForSeq2SeqLM.from_pretrained(_COEDIT_CHECKPOINT)
grammarly_model.to(device)
class OffensiveInput(BaseModel):
    """Request body for the offensive-text endpoint: a single string field."""

    # The raw text to be analysed.
    text: str
# Upper bound on input tokens fed to either model. Set explicitly so that
# truncation does not depend on the tokenizer config declaring a maximum.
# NOTE(review): 512 is assumed to fit both checkpoints -- confirm against
# their model configs.
_MAX_INPUT_TOKENS = 512
# Output budget for the rewritten text; 256 is the value used in the CoEdIT
# model card example.
_MAX_OUTPUT_TOKENS = 256


@app.post("/detect_offensive/")
async def detect_offensive(input: OffensiveInput):
    """Score a text for offensiveness and return a model-edited version of it.

    Args:
        input: Request body; its ``text`` field is the text to analyse. The
            parameter name shadows the builtin but is kept so the handler's
            signature stays unchanged.

    Returns:
        A dict with two keys: ``offensive_score``, the softmax probability at
        class index 1 of the classifier (the offensive class), and
        ``corrected_text``, the decoded output of the seq2seq model.
    """
    # Prepare the text input for the classifier; over-long text is truncated
    # instead of being passed through at full length.
    inputs = tokenizer(
        input.text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=_MAX_INPUT_TOKENS,
    ).to(device)

    # Inference only, so no gradients are tracked.
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
        offensive_score = probabilities[:, 1].item()  # Score for the offensive class

        # Prepare the text input for the Grammarly model, with the same
        # input-length cap.
        grammarly_inputs = grammarly_tokenizer(
            input.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=_MAX_INPUT_TOKENS,
        ).to(device)

        # Generate the corrected and rephrased text. An explicit max_length is
        # passed because the library's default generation length is short and
        # would cut longer rewrites off mid-sentence.
        grammarly_output = grammarly_model.generate(
            **grammarly_inputs, max_length=_MAX_OUTPUT_TOKENS
        )

    corrected_text = grammarly_tokenizer.decode(
        grammarly_output[0], skip_special_tokens=True
    )

    # Return the offensive score and the corrected text
    return {"offensive_score": offensive_score, "corrected_text": corrected_text}
# Start server: open an ngrok tunnel to the local port, show a link to the
# endpoint, then run uvicorn on that port.
import os

from IPython.display import display  # explicit, instead of relying on the notebook-injected global

SERVER_PORT = 8000

# The token may be supplied through the NGROK_AUTH_TOKEN environment variable
# so it does not have to be pasted into the notebook; the original placeholder
# remains the fallback.
ngrok.set_auth_token(os.environ.get("NGROK_AUTH_TOKEN", "ngrok_auth_token_goes_here"))
tunnels = ngrok.get_tunnels()
print("tunnels are:", tunnels)
ngrok_tunnel = ngrok.connect(SERVER_PORT, domain="lately-absolute-thrush.ngrok-free.app")

# Join host and path with exactly one slash. Plain concatenation produced
# "...ngrok-free.appdetect_offensive/" when public_url has no trailing slash.
app_url = ngrok_tunnel.public_url.rstrip('/') + '/detect_offensive/'
print('Your app is hosted at', ngrok_tunnel.public_url)

# Display clickable link
# NOTE(review): the route is registered with @app.post only, so opening this
# link in a browser (a GET request) will presumably not return a result.
display(HTML(f'<a href="{app_url}" target="_blank">Open app in a new tab</a>'))

# Patch asyncio to allow nested event loops before starting uvicorn
# (presumably needed because the notebook already runs a loop).
nest_asyncio.apply()
uvicorn.run(app, port=SERVER_PORT)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment