Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
def get_directory(url):
    """Turn a URL into a flat directory name.

    Every "/" in *url* is replaced by "_" and the result is prefixed with
    "URL_", e.g. ``"a/b"`` becomes ``"URL_a_b"``.
    """
    flattened = str(url.replace("/", "_"))
    return f"URL_{flattened}"
def get_category(model, imagenet_class_mapping, image_path):
    """Return the class-mapping entry the model predicts for one image file.

    Args:
        model: Classifier applied to the transformed image. Its output is
            reduced with ``.max(1)`` along dimension 1, presumably a
            ``(1, num_classes)`` logits tensor (TODO confirm).
        imagenet_class_mapping: Mapping keyed by the predicted class index
            as a *string* (e.g. ``"281"``). The entry is returned unchanged.
            ``get_prediction`` takes element ``[1]`` of it, so it is
            presumably ``[wnid, label]``; verify against the loader.
        image_path: Path of the image file to read (opened in binary mode).

    Returns:
        ``imagenet_class_mapping[str(predicted_index)]``.

    Raises:
        KeyError: if a dict-like mapping has no entry for the predicted index.
    """
    import torch  # local import keeps this block self-contained

    with open(image_path, "rb") as image_file:
        image_bytes = image_file.read()
    transformed_image = transform_image(image_bytes=image_bytes)
    # Call the module rather than .forward() so registered hooks still run,
    # and disable autograd: this is inference only, so no graph is needed.
    with torch.no_grad():
        outputs = model(transformed_image)
    _, category = outputs.max(1)
    # .item() only works on a one-element tensor, i.e. a batch of one image.
    predicted_idx = str(category.item())
    return imagenet_class_mapping[predicted_idx]
def get_prediction(model, imagenet_class_mapping, path_to_directory):
    """Classify every image file in a directory and map each path to its tag.

    The returned dictionary (image path -> predicted class) is what the
    HTML file is generated from.

    Args:
        model: Classifier, forwarded to ``get_category``.
        imagenet_class_mapping: Class mapping, forwarded to ``get_category``.
            Element ``[1]`` of the entry it returns is used as the tag.
        path_to_directory: Directory holding the images to classify.

    Returns:
        dict mapping each image file path to its tag, in sorted path order.
    """
    import os  # local import keeps this block self-contained

    # Escape the directory part so glob metacharacters in it (e.g. "[" in a
    # URL-derived name) match literally; only the trailing "*" is a wildcard.
    pattern = os.path.join(glob.escape(path_to_directory), "*")
    # Sort for a deterministic page order. Skip subdirectories because
    # get_category opens each path as a file and would raise on a directory.
    image_files = sorted(
        path for path in glob.glob(pattern) if os.path.isfile(path)
    )
    return {
        image_file: get_category(
            model, imagenet_class_mapping, image_path=image_file
        )[1]
        for image_file in image_files
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.