# Shiny application that classifies uploaded images with an ONNX model.
# Gist by KennethTM, created December 16, 2022:
# https://gist.github.com/KennethTM/77f4306e0e1ee907a4e36cd0ce8a36d6
# Python environment configuration for reticulate.
# Adapted from https://github.com/ranikay/shiny-reticulate-app
#
# When the process runs as the "shiny" user (shinyapps.io), reticulate is
# pointed at the interpreter inside the named virtualenv and `python3` is used
# to create it. Otherwise (local development), `python` is used and the app is
# pinned to port 7450.
VIRTUALENV_NAME <- "onnx_runtime"
if (Sys.info()[["user"]] != "shiny") {
  # Local development
  options(shiny.port = 7450)
  Sys.setenv(PYTHON_PATH = "python")
  Sys.setenv(VIRTUALENV_NAME = VIRTUALENV_NAME)
} else {
  # Deployed on shinyapps.io
  Sys.setenv(
    PYTHON_PATH = "python3",
    VIRTUALENV_NAME = VIRTUALENV_NAME,
    RETICULATE_PYTHON = paste0(
      "/home/shiny/.virtualenvs/", VIRTUALENV_NAME, "/bin/python"
    )
  )
}
library(shiny)
library(OpenImageR)
library(reticulate)
# User interface: the sidebar holds a PNG/JPEG upload control; the main panel
# shows a 224x224 image display ("image") above a text label ("image_label").
ui <- fluidPage(
  titlePanel("Image classification app"),
  sidebarLayout(
    sidebarPanel = sidebarPanel(
      fileInput(
        inputId = "file",
        label = "Choose a file",
        accept = c("image/png", "image/jpeg")
      )
    ),
    mainPanel = mainPanel(
      imageOutput(outputId = "image", width = "224px", height = "224px"),
      textOutput(outputId = "image_label")
    )
  )
)
# Per-process cache for the onnxruntime session. Preparing the virtualenv runs
# virtualenv_install(..., ignore_installed = TRUE), which reinstalls
# onnxruntime every time it is called, so it is done once per R process
# instead of once per connected browser session.
onnx_session_cache <- new.env(parent = emptyenv())

# Prepare the Python virtualenv on first use and return a cached onnxruntime
# InferenceSession for `model_path`.
#
# The virtualenv name and the Python used to create it are read from the
# VIRTUALENV_NAME and PYTHON_PATH environment variables.
get_onnx_session <- function(model_path = "model.onnx") {
  if (is.null(onnx_session_cache$session)) {
    virtualenv_dir <- Sys.getenv("VIRTUALENV_NAME")
    python_path <- Sys.getenv("PYTHON_PATH")
    python_dependencies <- c("onnxruntime")
    virtualenv_create(envname = virtualenv_dir, python = python_path)
    virtualenv_install(
      virtualenv_dir,
      packages = python_dependencies,
      ignore_installed = TRUE
    )
    use_virtualenv(virtualenv_dir, required = TRUE)
    ort <- import("onnxruntime")
    onnx_session_cache$session <- ort$InferenceSession(model_path)
  }
  onnx_session_cache$session
}

# Return `img` as a rows x cols x 3 array.
#
# A 2-D (grayscale) matrix, or an array with fewer than 3 channels, has its
# first channel replicated three times; channels beyond the third (e.g. a PNG
# alpha channel) are dropped.
to_rgb_array <- function(img) {
  if (length(dim(img)) == 2L) {
    img <- array(img, dim = c(dim(img), 1L))
  }
  channels <- if (dim(img)[3] >= 3L) seq_len(3L) else rep(1L, 3L)
  img[, , channels, drop = FALSE]
}

# Named c(width, height) resize target that scales the smaller of
# img_w / img_h to `init_size` while preserving the aspect ratio.
resize_dims <- function(img_w, img_h, init_size = 256) {
  if (img_h > img_w) {
    c(width = init_size, height = init_size * (img_h / img_w))
  } else {
    c(width = init_size / (img_h / img_w), height = init_size)
  }
}

# Resize, center-crop and reshape an image read by OpenImageR::readImage()
# (similar to PyTorch preprocessing steps).
#
# Returns a list with `tensor` (a 1 x 3 x final_size x final_size array,
# channels first with a leading batch axis) and `display` (the cropped
# 3-channel image, for showing in the UI).
preprocess_image <- function(img, init_size = 256, final_size = 224) {
  img <- to_rgb_array(img)
  # NOTE(review): OpenImageR documents resizeImage()'s `width` as the image
  # rows, which is why dim(img)[1] is treated as the width here - confirm.
  target <- resize_dims(
    img_w = dim(img)[1],
    img_h = dim(img)[2],
    init_size = init_size
  )
  img_resize <- resizeImage(
    img,
    width = target[["width"]],
    height = target[["height"]],
    method = "bilinear",
    normalize_pixels = TRUE
  )
  img_crop <- cropImage(
    img_resize,
    type = "equal_spaced",
    new_width = final_size,
    new_height = final_size
  )
  # rows x cols x 3  ->  3 x rows x cols  ->  1 x 3 x rows x cols
  tensor <- aperm(img_crop, c(3, 1, 2))
  dim(tensor) <- c(1, dim(tensor))
  list(tensor = tensor, display = img_crop)
}

# Run `session` on `tensor` (fed under the model input name `input_name`) and
# return list(label, prob) for the highest-scoring entry of the model's FIRST
# output.
#
# NOTE(review): `prob` is only a probability if that output is already
# softmax-normalised by the model - confirm against the exported model.
classify_tensor <- function(session, tensor, classes, input_name = "input") {
  feed <- list(np_array(tensor, dtype = "float32"))
  names(feed) <- input_name
  outputs <- session$run(py_none(), dict(feed))
  scores <- as.numeric(outputs[[1]])
  best <- which.max(scores)
  list(label = classes[best], prob = scores[best])
}

# Shiny server: classify each uploaded image and show the cropped image
# together with its predicted label.
server <- function(input, output) {
  ort_sess <- get_onnx_session("model.onnx")
  # One class label per line, in model output order.
  classes <- readLines("classes.txt")

  # Recomputed only when a new file is uploaded. With eventReactive()'s default
  # ignoreNULL = TRUE, nothing is rendered before the first upload, and an
  # error here is shown in the outputs instead of ending the session.
  prediction <- eventReactive(input$file, {
    prepared <- preprocess_image(readImage(input$file$datapath))
    c(prepared, classify_tensor(ort_sess, prepared$tensor, classes))
  })

  output$image_label <- renderText({
    res <- prediction()
    paste0(
      "Predicted class: ",
      res$label,
      " (",
      round(res$prob * 100, digits = 1),
      "%)"
    )
  })

  # The temp PNG is written inside the render expression, so every re-render
  # gets a fresh file even though deleteFile = TRUE removes the previous one.
  output$image <- renderImage({
    outfile <- tempfile(fileext = ".png")
    writeImage(prediction()$display, outfile)
    list(src = outfile)
  }, deleteFile = TRUE)
}
# Build the Shiny app object from the `ui` definition and `server` function.
shinyApp(ui, server)
# End of gist source.