Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save rahulremanan/f94d0ee7ce18414817658a8e92911823 to your computer and use it in GitHub Desktop.
Save rahulremanan/f94d0ee7ce18414817658a8e92911823 to your computer and use it in GitHub Desktop.
Numerical encoding of the FairFace race labels
# Directory that holds the Twitter saliency evaluation outputs.
twitter_saliency_eval_dir = os.path.join(img_dir, 'Twitter_saliency')
if not os.path.exists(twitter_saliency_eval_dir):
    print(f'No outputs directory: {twitter_saliency_eval_dir} found ...')
    # Create the directory directly instead of running a `mkdir <path>` shell
    # string. The shell form broke on paths with spaces or shell
    # metacharacters and did not create missing parent directories.
    os.makedirs(twitter_saliency_eval_dir, exist_ok=True)
    print(f'Created outputs directory: {twitter_saliency_eval_dir}')

# Label -> integer-code mapping, persisted as JSON. Reusing the saved file
# keeps the codes identical across runs instead of recomputing them.
labels_encoder_file = os.path.join(twitter_saliency_eval_dir, 'labels_encoder.json')
if os.path.exists(labels_encoder_file):
    with open(labels_encoder_file, encoding='utf-8') as f:
        labels_encoder = json.load(f)
    print(labels_encoder)
    print(f'Loaded labels encoder data from: {labels_encoder_file} ...')
else:
    print(f'No saved labels encoder data: {labels_encoder_file} ...')
    # Assign codes 0..k-1 to the distinct values of the 'race' column.
    # Sorting first makes the assignment deterministic.
    labels_encoder = {
        race: i for i, race in enumerate(sorted(set(img_labels['race'].values)))
    }
    print(labels_encoder)
    with open(labels_encoder_file, 'w', encoding='utf-8') as f:
        json.dump(labels_encoder, f)
    print(f'Saved labels encoder data to: {labels_encoder_file} ...')
def encoded_labels(input_label, labels_encoder):
    """Return the code that ``labels_encoder`` assigns to ``input_label``.

    Args:
        input_label: A label expected to be a key of ``labels_encoder``.
        labels_encoder: Mapping from label to its numeric code.

    Returns:
        The value stored under ``input_label`` in ``labels_encoder``.

    Raises:
        KeyError: If ``input_label`` is not a key of ``labels_encoder``.
    """
    code = labels_encoder[input_label]
    return code
def decoded_labels(input_label, labels_encoder):
    """Return the label whose code in ``labels_encoder`` equals ``input_label``.

    This is a linear reverse lookup. If several labels share the same code,
    the first one in the mapping's iteration order is returned.

    Args:
        input_label: A code expected to be one of ``labels_encoder``'s values.
        labels_encoder: Mapping from label to its numeric code.

    Returns:
        The key of the first entry whose value equals ``input_label``.

    Raises:
        ValueError: If no value in ``labels_encoder`` equals ``input_label``.
    """
    pairs = list(labels_encoder.items())
    codes = [code for _, code in pairs]
    # list.index gives first-match-by-== semantics and raises ValueError
    # when the code is absent.
    position = codes.index(input_label)
    label, _ = pairs[position]
    return label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment