Skip to content

Instantly share code, notes, and snippets.

@iaverypadberg
Created November 8, 2022 13:35
Show Gist options
  • Save iaverypadberg/da12e8894ee57ed60b51b8c87ba5c440 to your computer and use it in GitHub Desktop.
Save iaverypadberg/da12e8894ee57ed60b51b8c87ba5c440 to your computer and use it in GitHub Desktop.
A Python script for running TFLite inference on a folder of images.
from sqlite3 import adapt
import numpy as np
from PIL import Image
import glob
import cv2
import numpy as np
import tflite_runtime.interpreter as tflite
from pycoral.adapters import detect
from pycoral.utils.dataset import read_label_file
from PIL import Image
from PIL import ImageDraw
import collections
def draw_object(draw, obj):
    """Render a single detection through the given drawing handle.

    The bounding box is outlined in red. The label is written at
    ``(bbox[0], bbox[3])`` and the score, converted with ``str``, is written
    10 units further along the second coordinate.

    Args:
        draw: Drawing handle exposing ``rectangle`` and ``text`` (the script
            passes a ``PIL.ImageDraw`` object).
        obj: Detection candidate with ``bbox``, ``label`` and ``score``
            attributes.
    """
    box = obj.bbox
    draw.rectangle(box, outline='red')

    anchor_x, anchor_y = box[0], box[3]
    # Label first, then the score on the line underneath it.
    captions = (
        (anchor_y, obj.label),
        (anchor_y + 10, str(obj.score)),
    )
    for y, caption in captions:
        draw.text((anchor_x, y), caption, fill='#0000')
# Record handed to draw_object: label text, confidence score and the box as
# [xmin, ymin, xmax, ymax].
Object = collections.namedtuple('Object', ['label', 'score', 'bbox'])

model = '/home/isaac/Desktop/spaghetti/spag_models/spaghetti_spruce_2_train23.tflite'

# Sorted so that result_<n>.jpg refers to the same input file on every run;
# glob itself returns files in arbitrary order.
files = sorted(glob.glob("/home/isaac/Desktop/dataset/320_test/*.jpg"))

interpreter = tflite.Interpreter(model)
interpreter.allocate_tensors()
interpreter_input_details = interpreter.get_input_details()
labels = read_label_file("/home/isaac/Desktop/spaghetti/spaghetti_meta/spruce_2_labels.txt")

# Take the expected input size from the model instead of hard-coding it.
# NOTE(review): assumes the input tensor is laid out as
# (batch, height, width, channels) -- confirm for this model.
_, input_height, input_width, _ = (
    int(dim) for dim in interpreter_input_details[0]['shape']
)

for count, image in enumerate(files):
    img = Image.open(image).convert('RGB')
    draw = ImageDraw.Draw(img)

    # The model only accepts tensors of its own input size, so feed it a
    # resized copy and keep the full-size image for drawing and saving.
    # NOTE(review): pixel values are passed through unchanged (uint8), as in
    # the original script -- a float-input model would need normalization.
    resized = img.resize((input_width, input_height))
    interpreter.set_tensor(
        interpreter_input_details[0]['index'],
        np.expand_dims(np.asarray(resized), axis=0),
    )
    interpreter.invoke()

    # image_scale is the resize ratio applied to the original image; passing
    # it makes get_objects return boxes in original-image coordinates, which
    # is the coordinate system `draw` works in.
    image_scale = (input_width / img.width, input_height / img.height)
    objs = detect.get_objects(
        interpreter, score_threshold=.5, image_scale=image_scale
    )

    for obj in objs:
        bbox = [obj.bbox.xmin, obj.bbox.ymin, obj.bbox.xmax, obj.bbox.ymax]
        # Unknown class ids fall back to an empty label.
        label = labels.get(obj.id, '')
        draw_object(draw, Object(label, obj.score, bbox))

    img.save(f"/home/isaac/Desktop/dataset/320_test_complete/result_{count}.jpg")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment