Created
November 5, 2018 13:47
-
-
Save hkristen/971af4233952c506b8cfbcfc007c52c1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This webapp is based on Healthy or Not (https://github.com/nikhilno1/healthy-or-not) -> Thanks! | |
from starlette.applications import Starlette | |
from starlette.responses import JSONResponse, HTMLResponse, RedirectResponse | |
from fastai import * | |
from fastai.vision import * | |
import torch | |
from io import BytesIO | |
import uvicorn | |
import aiohttp | |
import os | |
import base64 | |
from PIL import Image as PILImage | |
async def get_bytes(url):
    """Download ``url`` and return the response body as bytes.

    Raises:
        aiohttp.ClientResponseError: if the server answers with a 4xx/5xx
            status, so an error page is never mistaken for image data.
        aiohttp.ClientError: for other client-side failures (bad URL,
            connection errors, ...).
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            # Fail fast on HTTP errors instead of returning the error body,
            # which would only blow up later inside image decoding.
            response.raise_for_status()
            return await response.read()
def encode(img):
    """Serialize a fastai image to a base64-encoded JPEG string.

    ``img.data`` is converted to a NumPy array with fastai's ``image2np``,
    scaled by 255 and cast to ``uint8``, then written out by PIL as JPEG.
    NOTE(review): assumes pixel values are in [0, 1] — confirm for inputs
    that were not produced by ``open_image``.
    """
    pixels = (image2np(img.data) * 255).astype('uint8')
    with BytesIO() as jpeg_buffer:
        PILImage.fromarray(pixels).save(jpeg_buffer, format="JPEG")
        raw_jpeg = jpeg_buffer.getvalue()
    return base64.b64encode(raw_jpeg).decode("utf-8")
# Run inference on the CPU, not the GPU. This must happen BEFORE the DataBunch
# and Learner are built: fastai v1 captures `defaults.device` when the
# DataBunch is created, moves the model there when the Learner is created, and
# maps the loaded weights to that device. Setting it afterwards leaves the model
# on whatever device was the default at construction time.
defaults.device = torch.device('cpu')

# Create a DataBunch from a CSV file that lists only one image per class: a
# cheap workaround so the learner gets the complete list of class labels.
data_path = 'data/ImageCLEF2011/'
images_clef2013_pd_sheet_one_class = pd.read_csv(
    os.path.join(data_path, 'images_clef2013_pd_one_class.csv'))
data_placeholder = ImageDataBunch.from_df(
    data_path,
    images_clef2013_pd_sheet_one_class,
    fn_col=2,
    label_col=0,
    ds_tfms=get_transforms(),
    size=224,
)

# Initialize a ResNet-34 CNN learner and load the trained weights saved under
# the name 'leaf_types_stage_1'.
learner = create_cnn(data_placeholder, models.resnet34)
learner.load('leaf_types_stage_1')

# Initialize the web app.
app = Starlette()
@app.route("/upload", methods=["POST"])
async def upload(request):
    """Classify an image uploaded as the ``file`` field of a multipart form.

    Returns the prediction page, or a 400 response when no non-empty file
    upload is present in the form.
    """
    form_data = await request.form()
    upload_file = form_data.get("file")
    # Depending on the client, submitting without choosing a file may omit the
    # field, send a plain string, or send an empty upload; reject all of these
    # with a 400 instead of crashing with a KeyError / decode error (500).
    if upload_file is None or not hasattr(upload_file, "read"):
        return HTMLResponse(
            "<p>Please select an image file to upload.</p>", status_code=400)
    img_bytes = await upload_file.read()
    if not img_bytes:
        return HTMLResponse(
            "<p>The uploaded file is empty.</p>", status_code=400)
    return predict_image_from_bytes(img_bytes)
@app.route("/classify-url", methods=["GET"])
async def classify_url(request):
    """Download the image at the ``url`` query parameter and classify it.

    Only absolute http(s) URLs are accepted. A missing or invalid URL, or a
    download that fails, yields a 400 response instead of an unhandled
    exception (HTTP 500).
    """
    import asyncio
    from urllib.parse import urlparse

    url = request.query_params.get("url", "")
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        return HTMLResponse(
            "<p>Please provide an absolute http(s) image URL.</p>",
            status_code=400)
    # NOTE(review): the URL is user-controlled, so this endpoint can be aimed
    # at internal hosts (SSRF). Consider restricting destinations if the app
    # is exposed publicly.
    try:
        img_bytes = await get_bytes(url)
    except (aiohttp.ClientError, asyncio.TimeoutError):
        return HTMLResponse(
            "<p>Could not download an image from that URL.</p>",
            status_code=400)
    return predict_image_from_bytes(img_bytes)
def predict_image_from_bytes(bytes):
    """Classify raw image bytes and render the result as an HTML page.

    Args:
        bytes: the encoded image file contents. The parameter name is kept for
            backward compatibility; it shadows the ``bytes`` builtin here.

    Returns:
        HTMLResponse showing the predicted class (upper-cased), the model's
        output for that class, and a thumbnail of the input image.
    """
    import html

    img = open_image(BytesIO(bytes))
    pred_class, pred_idx, outputs = learner.predict(img)
    confidence = outputs[pred_idx].item()
    # `pred_class` may be a fastai Category object rather than a str
    # (depending on the fastai version), and Category has no .upper();
    # str() works for both. Escape it because it is interpolated into HTML.
    label = html.escape(str(pred_class).upper())
    # encode() writes JPEG data, so the data URI must declare image/jpeg, and
    # there must be no space between the comma and the base64 payload.
    img_data = encode(img)
    return HTMLResponse(
        """
        <html>
        <body>
        <p>Prediction: <b>%s</b></p>
        <p>Confidence: %s</p>
        <figure class="figure">
        <img src="data:image/jpeg;base64,%s" class="figure-img img-thumbnail input-image">
        </figure>
        </body>
        </html>
        """ % (label, confidence, img_data))
@app.route("/")
def form(request):
    """Render the landing page with the file-upload form and the URL form.

    Both inputs are marked ``required`` (and the file picker is limited to
    images) so browsers block empty submissions before they reach the server.
    """
    return HTMLResponse(
        """
        <html>
        <body>
        <h1>Which plant leaf is this?</h1>
        <p>Find out which plant this leaf belongs to (based on
        <a href="https://www.imageclef.org/2013/plant">https://www.imageclef.org/2013/plant</a>).</p>
        <p>Upload an image or specify a URL.</p>
        <form action="/upload" method="post" enctype="multipart/form-data">
            <p><u>Select image to upload:</u></p>
            <p>1. <input type="file" name="file" accept="image/*" required></p>
            <p>2. <input type="submit" value="Upload and analyze image"></p>
        </form>
        <p><strong>OR</strong></p>
        <p><u>Submit a URL:</u></p>
        <form action="/classify-url" method="get">
            <p>1. <input type="url" name="url" size="60" required></p>
            <p>2. <input type="submit" value="Fetch and analyze image"></p>
        </form>
        </body>
        </html>
        """)
@app.route("/form")
def redirect_to_homepage(request):
    """Redirect requests for ``/form`` to the landing page at ``/``."""
    homepage = "/"
    return RedirectResponse(url=homepage)
# Port to listen on, read from the PORT environment variable (default 8008).
port = int(os.environ.get("PORT", 8008))

if __name__ == "__main__":
    # Only start the server when run as a script. Importing the module (for
    # example via `uvicorn <module>:app`) must not start and block on a
    # second server.
    uvicorn.run(app, host="0.0.0.0", port=port)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment