Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save rahulremanan/894d58aa3b9ee70c8dca52ad5729c1c3 to your computer and use it in GitHub Desktop.
Save rahulremanan/894d58aa3b9ee70c8dca52ad5729c1c3 to your computer and use it in GitHub Desktop.
Helper functions to map the saliency filter output to FairFace data
def saliency_to_image(input_image, s_point, images_list, padding=0, image_mode='horizontal'):
    """Return the entry of ``images_list`` whose strip contains the salient point.

    The composite ``input_image`` is treated as ``len(images_list)`` equally
    sized strips laid out along one axis, in the same order as
    ``images_list``. The first salient point is located along that axis and
    the matching list entry is returned.

    A point lying exactly on the boundary between two strips is attributed to
    the earlier strip (the comparison is strict).

    Args:
        input_image: Object with an indexable ``size`` attribute; index 0 is
            read in horizontal mode and index 1 in vertical mode.
        s_point: Non-empty sequence of salient points, each indexable the same
            way as ``size``. Only the first point is used.
        images_list: Non-empty sequence of items (e.g. file paths), one per
            strip of the composite image.
        padding: Currently unused; kept for interface compatibility.
        image_mode: ``'horizontal'`` or ``'vertical'``.

    Returns:
        The element of ``images_list`` corresponding to the salient point.

    Raises:
        ValueError: If ``image_mode`` is unsupported or ``images_list`` is
            empty.
    """
    if image_mode == 'horizontal':
        s_idx = 0
    elif image_mode == 'vertical':
        s_idx = 1
    else:
        raise ValueError('Unsupported image mode. \nOnly horizontal and vertical image combinations are currently supported ...')

    n_images = len(images_list)
    if n_images == 0:
        raise ValueError('images_list must contain at least one image ...')

    # Warn once, rather than once per candidate image.
    if len(s_point) > 1:
        warnings.warn('Only reading the first saliency point. \nParsing one saliency point is currently supported ...')

    extent = input_image.size[s_idx]
    coord = s_point[0][s_idx]

    # Start at the first strip and advance to the last strip whose leading
    # edge lies strictly before the salient point. The index must NOT be reset
    # inside the loop, otherwise only the first or last image is reachable.
    s_image_idx = 0
    for i in range(1, n_images):
        if coord > i * extent / n_images:
            s_image_idx = i
    return images_list[s_image_idx]
def saliency_point_to_info(input_file, image_files, model, df, image_mode='horizontal'):
    """Look up metadata for the source image picked out by the saliency model.

    Runs ``model`` on the composite image at ``input_file``, maps the returned
    salient point back to one entry of ``image_files`` via
    ``saliency_to_image``, and fetches that entry's record with ``img_info``.

    Args:
        input_file: Path (string or path-like) of the composite image.
        image_files: Source image paths, in the order they were joined into
            the composite. File names are expected to be ``<integer>.jpg``.
        model: Object whose ``get_output(path)`` returns a mapping with a
            ``'salient_point'`` entry.
        df: Label table passed through to ``img_info``.
        image_mode: ``'horizontal'`` or ``'vertical'``; forwarded to
            ``saliency_to_image``.

    Returns:
        Tuple ``(s_info, sp_)``: the ``img_info`` result for the selected
        image and the raw salient point(s) reported by the model.

    Raises:
        ValueError: If the selected file name (minus ``.jpg``) is not an
            integer, or if ``saliency_to_image`` rejects its arguments.
    """
    sp_ = model.get_output(Path(input_file))['salient_point']

    # Only the image dimensions are needed, so close the file handle as soon
    # as the lookup is done instead of relying on garbage collection.
    with Image.open(input_file) as img_:
        s_img_file = saliency_to_image(img_, sp_, image_files, image_mode=image_mode)

    # Take the final path component in an OS-independent way; works for both
    # Path objects and plain strings.
    sID = Path(str(s_img_file)).name.replace('.jpg', '')

    # The numeric file name is shifted down by one to get the index handed to
    # img_info (presumably 1-based file IDs vs 0-based rows — TODO confirm).
    s_info = img_info(df, int(sID) - 1)
    return s_info, sp_
# Notebook-style driver for the "_h" (horizontal) composite.
# Relies on names defined outside this snippet: data_dir, output_dir,
# filename, model, img_labels, labels_encoder, join_images, encoded_labels,
# decoded_labels, display and plt.

# Collect the .jpg files directly inside data_dir (the pattern is not recursive).
# NOTE(review): glob order is not guaranteed to be sorted; the same img_files
# order is used both for joining and for the lookup below.
img_files = list(data_dir.glob("./*.jpg"))
# NOTE(review): the opened images are never explicitly closed.
images = [Image.open(x) for x in img_files]
# col_wrap=2 presumably places the images side by side — TODO confirm against join_images.
img = join_images(images, col_wrap=2, img_size=(128, -1))
display(img)
# Persist the composite so the model can be run on it from disk below.
img.save(f"{output_dir}/{filename}_h.jpeg", "JPEG")
model.plot_img_crops_using_img(img, topK=5, col_wrap=6)
# Saves the current matplotlib figure (presumably the plot produced by the call above).
plt.savefig(f"{output_dir}/{filename}_h_sm.jpeg",bbox_inches="tight")
# Map the model's salient point on the saved composite back to a source image
# and fetch that image's record from img_labels.
saliency_info,sp = saliency_point_to_info(f"{output_dir}/{filename}_h.jpeg", img_files, model, img_labels, image_mode='horizontal')
# The next two results are not assigned; presumably they are shown as notebook
# cell output (encode, then an encode/decode round trip) — TODO confirm.
encoded_labels(saliency_info['race'],labels_encoder)
decoded_labels(encoded_labels(saliency_info['race'],labels_encoder),labels_encoder)
print(saliency_info,sp)
# Notebook-style driver for the "_v" (vertical) composite. Reuses img_files
# from the preceding cell and performs the lookup step by step, printing the
# intermediate values.
# Relies on names defined outside this snippet: img_files, output_dir,
# filename, model, img_labels, join_images, img_info, display and plt.

# NOTE(review): the opened images are never explicitly closed.
images = [Image.open(x) for x in img_files]
# col_wrap=1 presumably stacks the images in a single column — TODO confirm against join_images.
img = join_images(images, col_wrap=1, img_size=(128, -1))
display(img)
# Persist the composite so the model can be run on it from disk below.
img.save(f"{output_dir}/{filename}_v.jpeg", "JPEG")
model.plot_img_crops_using_img(img, topK=5, col_wrap=6)
# Saves the current matplotlib figure (presumably the plot produced by the call above).
plt.savefig(f"{output_dir}/{filename}_v_sm.jpeg",bbox_inches="tight")
salient_point = model.get_output(Path(f"{output_dir}/{filename}_v.jpeg"))['salient_point']
print(salient_point)
# Uses the in-memory composite (img) for its size, index 1 of size and of the
# point being compared in 'vertical' mode.
saliency_image = saliency_to_image(img, salient_point, img_files, image_mode='vertical')
# Requires the selected entry to be a Path-like object (img_files comes from glob above).
saliency_filename = saliency_image.absolute()
print(f'Image picked by saliency filter: {saliency_filename}')
# Extract the numeric ID from a file name of the form "<number>.jpg".
# NOTE(review): splitting on '/' assumes POSIX-style path separators.
saliencyID = str(saliency_filename).split('/')[-1].replace('.jpg','')
# The ID is shifted down by one before the lookup (presumably 1-based file
# names vs 0-based rows — TODO confirm against img_info).
saliency_info = img_info(img_labels, int(saliencyID)-1)
print(saliency_info)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment