import threading
from typing import Dict, List

from fastapi import FastAPI
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
from starlette.middleware.cors import CORSMiddleware

# Beam width shared by the paraphrasing and summarization pipelines.
NUM_BEAMS = 8

# Each pipeline is constructed once, at import time, and the same instances
# are reused by every call to the /predict endpoint.
paraphrasing_pipeline = ParaphraseOnnxPipeline(num_beams=NUM_BEAMS)
ner_pipeline = NEROnnxModel()
summarization_pipeline = SummarizeOnnxPipeline(num_beams=NUM_BEAMS)
keyword_pipeline = GetKeywords()

app = FastAPI()

# The JavaScript front-end talks to this server from another origin, so
# cross-origin requests are accepted from any host, with any method/header.
# NOTE(review): a wildcard origin is convenient for development; consider
# restricting allow_origins to the front-end's real origin in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# NOTE(review): this model's name shadows starlette/FastAPI's ``Request``
# class; importing that class into this module would collide with it.
class Request(BaseModel):
    """Request body for ``POST /predict``: the text to run through every model."""
    text: str  # passed as-is to each pipeline and echoed back as ``original``
      
...

class KeywordResponse(BaseModel):
    """Wraps the result of ``keyword_pipeline.get_synonyms_for_keywords``."""

    # Presumably maps each extracted keyword to its list of synonyms
    # (inferred from the pipeline method's name -- confirm against GetKeywords).
    response: Dict[str, List[str]]
      
class AllModelsResponse(BaseModel):
    """Combined response of ``POST /predict``: the input text plus every model's output."""
    original: str  # the request text, echoed back unchanged
    paraphrased: ParagraphResponse  # output of the paraphrasing pipeline
    # NOTE(review): probably meant "named_entities", but the field name is part
    # of the public JSON contract -- rename only together with the clients.
    name_entities: NERResponse  # NER pipeline output, stored as ``render_data``
    summarized: ParagraphResponse  # output of the summarization pipeline
    keyword_synonyms: KeywordResponse  # result of get_synonyms_for_keywords

# The original handler had no ``await``, so each request ran to completion on
# the event loop and pipeline calls never overlapped. Now that inference runs
# in worker threads, this lock keeps that one-at-a-time guarantee, because the
# pipeline wrappers are not known to be thread-safe.
_inference_lock = threading.Lock()


def _run_all_models(text: str) -> AllModelsResponse:
    """Run every pipeline on *text* and bundle the results (blocking call).

    Intended to run in a worker thread, never directly on the event loop.
    """
    with _inference_lock:
        paraphrased = paraphrasing_pipeline(text)
        entities = ner_pipeline(text)
        summarized = summarization_pipeline(text)
        synonyms = keyword_pipeline.get_synonyms_for_keywords(text)
    return AllModelsResponse(
        original=text,
        paraphrased=ParagraphResponse(text=paraphrased),
        name_entities=NERResponse(render_data=entities),
        summarized=ParagraphResponse(text=summarized),
        keyword_synonyms=KeywordResponse(response=synonyms),
    )


@app.post("/predict", response_model=AllModelsResponse)
async def predict(request: Request) -> AllModelsResponse:
    """Run the paraphrasing, NER, summarization and keyword models on the input.

    The pipeline calls are synchronous. Calling them directly inside this
    coroutine would block the event loop for the whole inference, stalling
    every other request, so they are dispatched to Starlette's thread pool.
    """
    return await run_in_threadpool(_run_all_models, request.text)