Created
January 9, 2020 17:13
-
-
Save mrezzamoradi/fac8720e9285a6710e6c251f6b32b6c2 to your computer and use it in GitHub Desktop.
Sample API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from typing import List | |
from starlette.staticfiles import StaticFiles | |
from fastapi import FastAPI, Query, Body, Form, File, UploadFile, HTTPException | |
from crf_service.settings.paths import TRAINSETS_DIR | |
from crf_service.sentence_labeler.crf_handler import crf_handler | |
from crf_service.sentence_labeler.feature import feature_manager | |
from crf_service.models import FeatureData, LookupFeature | |
# ASGI application object; every route in this module is registered on it.
app = FastAPI(title='CRF Service')
# Serve the files stored under TRAINSETS_DIR as static content below the /trainset URL prefix.
app.mount(path='/trainset', app=StaticFiles(directory=TRAINSETS_DIR), name='trainset')
def save_trainset(file_name: str, content: bytes) -> str:
    """
    save trainset. This will overwrite the trainset with the same name if it already exists.

    Only the final path component of `file_name` is used, so the file is always written directly inside
    `TRAINSETS_DIR`, even when the given name contains directory parts (e.g. `../../etc/passwd` or `/tmp/x`).

    :param file_name: name of the trainset
    :param content: content of the trainset
    :return: path to the trainset
    :raises ValueError: if `file_name` has no usable file name component (empty, `.` or `..`)
    """
    # The name can originate from a client-supplied upload filename, so treat it as untrusted:
    # normalise Windows separators, then keep only the last component to block path traversal.
    safe_name = os.path.basename(file_name.replace('\\', '/'))
    if safe_name in ('', '.', '..'):
        raise ValueError(f'invalid trainset file name: {file_name!r}')
    trainset_path = os.path.join(TRAINSETS_DIR, safe_name)
    with open(trainset_path, 'wb') as f:
        f.write(content)
    return trainset_path
@app.get('/models', summary='get list of CRF models', tags=['CRF models'])
async def get_models(detailed: bool = Query(False, title='show accuracy reports')):
    """
    get list of the CRF models loaded to CRF handler which can be used for tagging.
    - **detailed**: whether to show details or not
    """
    # NOTE: the docstring above is published by FastAPI as this endpoint's description, so it is user-facing text.
    # `detailed` is forwarded unchanged as the handler's `with_stats` switch.
    models = crf_handler.get_models_list(with_stats=detailed)
    return models
@app.get('/models/{model_name}/reports', summary='get accuracy reports of a model', tags=['CRF models'])
async def get_reports(model_name: str):
    """
    get accuracy reports generated when the `model_name` is trained
    - **model_name**: name of the model
    """
    loaded_models = crf_handler.get_models_list()
    if model_name in loaded_models:
        return crf_handler.get_model_stats(model_name=model_name)
    # Unknown model: answer with 404 instead of asking the handler for stats.
    raise HTTPException(status_code=404, detail=f'model not found: {model_name}')
@app.get('/models/{model_name}/configs', summary='get model configs', tags=['CRF models'])
async def get_configs(model_name: str):
    """
    get the configs which a model is trained with
    - **model_name**: name of the model
    """
    if model_name not in crf_handler.get_models_list():
        raise HTTPException(status_code=404, detail=f'model not found: {model_name}')
    stored_configs = crf_handler.get_config(model_name=model_name)
    # Replace each trainset path with the URL under which the mounted /trainset static route serves that file.
    trainset_urls = [f'/trainset/{os.path.basename(item)}' for item in stored_configs['paths_to_trainsets']]
    # Return a new dict instead of assigning into `stored_configs`: that object comes from the handler and
    # may be its own stored state (not visible from here), which would then lose the real trainset paths.
    return {**stored_configs, 'paths_to_trainsets': trainset_urls}
@app.post('/models/{model_name}/tag', summary='tag a sentence', tags=['CRF models'])
async def tag_sentence(model_name: str, sentence: str = Body(..., embed=True)):
    """
    tag a sentence using a given model.
    - **model_name**: name of the model
    - **sentence**: the string to tag
    """
    # Same unknown-model guard as the reports/configs endpoints, so a wrong name yields a 404
    # instead of whatever error the handler raises for a model it does not have.
    if model_name not in crf_handler.get_models_list():
        raise HTTPException(status_code=404, detail=f'model not found: {model_name}')
    tagged_sentence = crf_handler.tag_sentence(model_name=model_name, sentence=sentence)
    # Echo the input next to the result so the caller can pair them up.
    return {'sentence': sentence, 'tagged_sentence': tagged_sentence}
@app.post('/models/{model_name}/train', summary='train a model', tags=['CRF models'])
async def train_model(model_name: str, trainset: List[UploadFile] = File(...), features: str = Form(None)):
    """
    train a model using the given trainset and features. If the model already exists, this will replace the model with
    the new one.
    - **model_name**: name of the model
    - **trainset**: a csv file formatted for CRF model training. The request can have many `trainset` fields
    (if you want to use multiple trainsets)
    - **features**: a comma-separated list of feature names e.g. city | city,number,currency
    (whitespace around names and empty names are ignored)
    """
    # Validate every upload before writing anything, so a bad request does not leave partial files behind.
    if any(not file.filename for file in trainset):
        raise HTTPException(status_code=400, detail='every trainset file must have a file name')

    paths_to_trainsets = []
    for file in trainset:
        content = await file.read()
        try:
            trainset_path = save_trainset(file_name=file.filename, content=content)
        except ValueError as error:
            # The client-supplied name was unusable (e.g. open() raises ValueError on an embedded NUL byte);
            # that is a client error, not a server failure.
            raise HTTPException(
                status_code=400, detail=f'invalid trainset file name: {file.filename!r}') from error
        paths_to_trainsets.append(trainset_path)

    # Parse into a separate name rather than rebinding the `features` string to a list.
    # Strip whitespace and drop empty entries so 'city, number,' yields ['city', 'number'];
    # when nothing usable remains, pass None exactly as when the form field is omitted.
    feature_names = None
    if features:
        feature_names = [name.strip() for name in features.split(',') if name.strip()] or None

    crf_handler.train_model(model_name=model_name, paths_to_trainsets=paths_to_trainsets, features=feature_names)
    return {
        'configs': crf_handler.get_config(model_name=model_name),
        'reports': crf_handler.get_model_stats(model_name=model_name)}
@app.get('/features/data', summary='show features data', tags=['Features'])
async def get_features_data():
    """get features data used in feature_manager"""
    # The manager's `feature_data` attribute is returned as-is and serialised by FastAPI.
    features_data = feature_manager.feature_data
    return features_data
@app.post('/features/data/add', summary='add feature data', tags=['Features'])
async def add_features_data(feature_info: FeatureData):
    """add a set of phrases under a label to features data"""
    label = feature_info.label
    # Duplicated phrases in the request collapse here because the manager is given a set.
    unique_phrases = set(feature_info.phrases)
    feature_manager.add_feature_data(label=label, phrases=unique_phrases)
    return {'response': f'added {label!r} to the features data'}
@app.post('/features/add/lookup', summary='add lookup feature', tags=['Features'])
async def add_lookup_feature(feature: LookupFeature):
    """
    adds a lookup feature to the feature manager
    Note: you must add feature data before adding a lookup feature
    """
    label = feature.label
    if label in feature_manager.feature_data:
        feature_manager.add_lookup_handler(label=label, phrase_max_length=feature.phrase_max_length)
        return {'response': f'added a handler to handle {label!r} lookup feature'}
    # No feature data registered under this label yet, so there is nothing to look up against.
    raise HTTPException(
        status_code=404, detail=f'{label!r} not found in feature data. Consider adding them first')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment