Skip to content

Instantly share code, notes, and snippets.

@lakshay-arora
Created June 28, 2020 17:00
Show Gist options
  • Save lakshay-arora/a6513b0e6035cc073425d6ec52caeae2 to your computer and use it in GitHub Desktop.
Save lakshay-arora/a6513b0e6035cc073425d6ec52caeae2 to your computer and use it in GitHub Desktop.
# get_directory function from the get-images file
def get_directory(url):
    """Build a flat directory name from *url*.

    Each "/" in ``url`` becomes "_", and the result is prefixed with "URL_".
    For example, "http://a/b" becomes "URL_http:__a_b".
    """
    flattened = str(url.replace("/", "_"))
    return f"URL_{flattened}"
# Get the predicted class of a single image file.
def get_category(model, imagenet_class_mapping, image_path):
    """Predict the class-mapping entry for a single image file.

    Args:
        model: the classifier. Its output must support ``.max(1)``; presumably
            a tensor of class scores shaped (batch, num_classes) — TODO confirm.
        imagenet_class_mapping: mapping keyed by the predicted class index
            *as a string* (e.g. "281").
        image_path: path of the image file to classify.

    Returns:
        ``imagenet_class_mapping[str(predicted_index)]``.

    Raises:
        KeyError: if the predicted index is not a key of the mapping.
    """
    import torch  # local import keeps this block self-contained

    # Read the raw bytes and close the file before running inference.
    with open(image_path, 'rb') as file:
        image_bytes = file.read()
    transformed_image = transform_image(image_bytes=image_bytes)
    # Inference only: no_grad() skips building an unused autograd graph.
    # Calling the module (rather than .forward) ensures registered hooks run.
    with torch.no_grad():
        outputs = model(transformed_image)
    _, category = outputs.max(1)
    # .item() requires a single prediction, i.e. a batch of one image.
    predicted_idx = str(category.item())
    return imagenet_class_mapping[predicted_idx]
# Create a dictionary mapping each image path to its predicted class;
# that dictionary is used to generate the HTML file.
def get_prediction(model, imagenet_class_mapping, path_to_directory):
    """Classify every file directly inside *path_to_directory*.

    Args:
        model: classifier passed through to ``get_category``.
        imagenet_class_mapping: mapping passed through to ``get_category``.
        path_to_directory: directory whose files are classified
            (non-recursive; subdirectories are skipped).

    Returns:
        dict mapping each file path to element ``[1]`` of its mapping entry,
        in sorted path order. Presumably each entry is a sequence like
        ``[wnid, label]``, so ``[1]`` is the readable label — TODO confirm.
    """
    import os

    # Escape the directory so glob metacharacters in its name ('[', '*', '?')
    # are matched literally.
    pattern = os.path.join(glob.escape(path_to_directory), '*')
    # Only regular files can be opened and classified; sort for stable output.
    files = sorted(path for path in glob.glob(pattern) if os.path.isfile(path))
    return {
        image_file: get_category(
            model, imagenet_class_mapping, image_path=image_file
        )[1]
        for image_file in files
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment