Skip to content

Instantly share code, notes, and snippets.

@mrezzamoradi
Created January 9, 2020 17:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mrezzamoradi/fac8720e9285a6710e6c251f6b32b6c2 to your computer and use it in GitHub Desktop.
Save mrezzamoradi/fac8720e9285a6710e6c251f6b32b6c2 to your computer and use it in GitHub Desktop.
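Sample API: a FastAPI service for training CRF models, tagging sentences with them, and managing lookup-feature data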
import os
from typing import List, Optional

from fastapi import FastAPI, Query, Body, Form, File, UploadFile, HTTPException
from starlette.staticfiles import StaticFiles

from crf_service.settings.paths import TRAINSETS_DIR
from crf_service.sentence_labeler.crf_handler import crf_handler
from crf_service.sentence_labeler.feature import feature_manager
from crf_service.models import FeatureData, LookupFeature
app = FastAPI(title='CRF Service')
# Expose the saved trainset files over HTTP under /trainset/<file name>.
app.mount(
    '/trainset',
    StaticFiles(directory=TRAINSETS_DIR),
    name='trainset',
)
def save_trainset(file_name: str, content: bytes, *, directory: Optional[str] = None) -> str:
    """
    save trainset. This will overwrite the trainset with the same name if it already exists.

    Only the last path component of `file_name` is used (forward slashes and backslashes both
    count as separators), so the file is always written directly inside the target directory,
    even when the name comes from an untrusted upload.
    :param file_name: name of the trainset
    :param content: content of the trainset
    :param directory: directory to write into; defaults to TRAINSETS_DIR
    :return: path to the trainset
    :raises ValueError: if `file_name` has no usable base name (empty, `.` or `..`)
    """
    # Normalise Windows-style separators first so basename also strips e.g. `C:\dir\` prefixes.
    base_name = os.path.basename((file_name or '').replace('\\', '/'))
    if base_name in ('', '.', '..'):
        raise ValueError(f'invalid trainset file name: {file_name!r}')
    target_dir = TRAINSETS_DIR if directory is None else directory
    trainset_path = os.path.join(target_dir, base_name)
    with open(trainset_path, 'wb') as f:
        f.write(content)
    return trainset_path
@app.get('/models', summary='get list of CRF models', tags=['CRF models'])
async def get_models(detailed: bool = Query(False, title='show accuracy reports')):
    """
    List the CRF models loaded in the CRF handler, i.e. the models available for tagging.

    - **detailed**: when true, ask the handler to include model stats in the listing
    """
    models = crf_handler.get_models_list(with_stats=detailed)
    return models
@app.get('/models/{model_name}/reports', summary='get accuracy reports of a model', tags=['CRF models'])
async def get_reports(model_name: str):
    """
    Return the accuracy reports generated when `model_name` was trained.

    - **model_name**: name of the model

    Responds with 404 when the model is not loaded in the CRF handler.
    """
    known_models = crf_handler.get_models_list()
    if model_name in known_models:
        return crf_handler.get_model_stats(model_name=model_name)
    raise HTTPException(status_code=404, detail=f'model not found: {model_name}')
@app.get('/models/{model_name}/configs', summary='get model configs', tags=['CRF models'])
async def get_configs(model_name: str):
    """
    get the configs which a model is trained with
    - **model_name**: name of the model

    Trainset paths are rewritten to their public `/trainset/<file name>` URLs.
    Responds with 404 when the model is not loaded in the CRF handler.
    """
    if model_name not in crf_handler.get_models_list():
        raise HTTPException(status_code=404, detail=f'model not found: {model_name}')
    # Work on a shallow copy: get_config may return the handler's own dict, and assigning the
    # rewritten list in place would replace the stored filesystem paths with URL paths.
    configs = dict(crf_handler.get_config(model_name=model_name))
    configs['paths_to_trainsets'] = [
        f'/trainset/{os.path.basename(path)}' for path in configs['paths_to_trainsets']
    ]
    return configs
@app.post('/models/{model_name}/tag', summary='tag a sentence', tags=['CRF models'])
async def tag_sentence(model_name: str, sentence: str = Body(..., embed=True)):
    """
    tag a sentence using a given model.
    - **model_name**: name of the model
    - **sentence**: the string to tag

    Responds with 404 when the model is not loaded in the CRF handler.
    """
    # Same existence check as the other /models/{model_name} endpoints, so an unknown model
    # gets a clean 404 rather than depending on how the handler treats unknown names.
    if model_name not in crf_handler.get_models_list():
        raise HTTPException(status_code=404, detail=f'model not found: {model_name}')
    tagged_sentence = crf_handler.tag_sentence(model_name=model_name, sentence=sentence)
    return {'sentence': sentence, 'tagged_sentence': tagged_sentence}
@app.post('/models/{model_name}/train', summary='train a model', tags=['CRF models'])
async def train_model(model_name: str, trainset: List[UploadFile] = File(...), features: str = Form(None)):
    """
    train a model using the given trainset and features. If the model already exists, this will replace the model with
    the new one.
    - **model_name**: name of the model
    - **trainset**: a csv file formatted for CRF model training. The request can have many `trainset` fields
    (if you want to use multiple trainsets)
    - **features**: a comma-separated list of feature names e.g. city | city,number,currency

    Uploaded files are stored under their base file name only. If any upload has an empty, `.` or `..`
    name, the request is rejected with 400 before anything is written. Whitespace around feature
    names is ignored, and an empty `features` value is treated as "no features" (None).
    """
    # The client controls `filename`; keep only its last path component so names like `../../x`
    # or `C:\dir\x` cannot write outside the trainsets directory. Validate every name up front
    # so a bad upload does not leave earlier files half-saved.
    file_names = []
    for upload in trainset:
        file_name = os.path.basename((upload.filename or '').replace('\\', '/'))
        if file_name in ('', '.', '..'):
            raise HTTPException(status_code=400, detail=f'invalid trainset file name: {upload.filename!r}')
        file_names.append(file_name)

    paths_to_trainsets = []
    for upload, file_name in zip(trainset, file_names):
        trainset_path = save_trainset(file_name=file_name, content=await upload.read())
        paths_to_trainsets.append(trainset_path)

    feature_names = None
    if features:
        feature_names = [name.strip() for name in features.split(',') if name.strip()] or None

    # NOTE(review): training runs synchronously inside this coroutine, so it blocks the event loop
    # for its whole duration; consider offloading it once crf_handler's thread-safety is confirmed.
    crf_handler.train_model(model_name=model_name, paths_to_trainsets=paths_to_trainsets, features=feature_names)
    return {
        'configs': crf_handler.get_config(model_name=model_name),
        'reports': crf_handler.get_model_stats(model_name=model_name)}
@app.get('/features/data', summary='show features data', tags=['Features'])
async def get_features_data():
    """Return the feature data currently held by feature_manager."""
    data = feature_manager.feature_data
    return data
@app.post('/features/data/add', summary='add feature data', tags=['Features'])
async def add_features_data(feature_info: FeatureData):
    """Register the given phrases (de-duplicated) under the given label in the features data."""
    label = feature_info.label
    unique_phrases = set(feature_info.phrases)
    feature_manager.add_feature_data(label=label, phrases=unique_phrases)
    return {'response': f'added {label!r} to the features data'}
@app.post('/features/add/lookup', summary='add lookup feature', tags=['Features'])
async def add_lookup_feature(feature: LookupFeature):
    """
    Attach a lookup handler for `feature.label` to the feature manager.

    The label must already exist in the feature data (add it via `/features/data/add`
    first); otherwise a 404 is returned.
    """
    label = feature.label
    if label in feature_manager.feature_data:
        feature_manager.add_lookup_handler(label=label, phrase_max_length=feature.phrase_max_length)
        return {'response': f'added a handler to handle {label!r} lookup feature'}
    raise HTTPException(
        status_code=404, detail=f'{label!r} not found in feature data. Consider adding them first')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment