Skip to content

Instantly share code, notes, and snippets.

@EnisBerk
Last active December 28, 2021 23:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save EnisBerk/09414c8834a9e171c9f3bb921189aabb to your computer and use it in GitHub Desktop.
Save EnisBerk/09414c8834a9e171c9f3bb921189aabb to your computer and use it in GitHub Desktop.
def delete_samples_by_shapley(dataset,
                              shapley_values,
                              percentage,
                              strategy='best2worst'):
    """Keep the top ``percentage`` of samples ranked by Shapley value.

    The entries of ``shapley_values`` are ranked according to ``strategy``.
    The first ``int(len(shapley_values) * percentage)`` ranked entries are
    kept; the key of every remaining entry is deleted from ``dataset``
    in place.

    Args:
        dataset: audio dataset keyed by clip path; must support ``len()``
            and ``del dataset[key]``.
        shapley_values: shapley values dict {'Clip Path': shapley_value}.
        percentage: fraction of samples to keep, in the range (0, 1]. The
            number kept is truncated toward zero, so it is 0 when
            ``len(shapley_values) * percentage`` is below 1.
        strategy: 'best2worst' ranks the highest Shapley values first (so
            the highest-valued samples are kept); 'worst2best' ranks the
            lowest first (so the lowest-valued samples are kept).

    Returns:
        The same ``dataset`` object, after the deletions.

    Raises:
        AssertionError: if ``percentage`` is not in (0, 1].
        ValueError: if ``strategy`` is not one of the two supported names.
    """
    assert 0 < percentage <= 1, 'percentage should be in (0, 1]'
    if strategy == 'best2worst':
        descending = True
    elif strategy == 'worst2best':
        descending = False
    else:
        raise ValueError('strategy should be best2worst or worst2best')
    ranked = sorted(shapley_values.items(),
                    key=lambda item: item[1],
                    reverse=descending)
    # Number of top-ranked samples that survive the filtering.
    n_keep = int(len(ranked) * percentage)
    # delete bottom (1-percentage)% samples
    print('Delete bottom {}% samples'.format((1-percentage) * 100))
    print('database size', len(dataset))
    # A forward slice from n_keep drops exactly the entries ranked after the
    # kept ones. The previous reverse slice used n_keep as an exclusive stop,
    # so it left one sample too many in the dataset.
    for file_path, _ in ranked[n_keep:]:
        del dataset[file_path]
    print('database size after filtering', len(dataset))
    return dataset
# Toy example: ten samples keyed '1'..'10' with Shapley values 0.1..1.0,
# filtered with percentage=0.2 and the default strategy.
shapley_values = {str(i): i / 10 for i in range(1, 11)}
dataset = dict.fromkeys(shapley_values)
delete_samples_by_shapley(dataset, shapley_values, 0.2)
# How I load the data
def load_shapley(shapley_csv, sample_id_csv, filter_key):
    """Build a ``{file path: Shapley value}`` dict from two CSV files.

    Args:
        shapley_csv: path of the CSV with the Shapley values, e.g.
            '/scratch/arsyed/shapley/alaska-shapley.csv'. Each row is read
            for its 'sample_id' and ``filter_key`` fields.
        sample_id_csv: path of the CSV whose rows are read for their
            'sample_id' and 'file_path' fields.
        filter_key: name of the field in ``shapley_csv`` holding the value
            to load, e.g. 'shap_songbird'.

    Returns:
        dict mapping each file path to ``float(row[filter_key])``. If
        several rows resolve to the same path, the last one wins.

    Raises:
        KeyError: if a 'sample_id' from ``shapley_csv`` has no entry in
            ``sample_id_csv`` (when rows are plain dicts).
    """
    # NOTE(review): utils.read_csv is defined outside this block; presumably
    # it yields one dict-like row per CSV line -- confirm.
    shapley_rows = utils.read_csv(shapley_csv)
    sample_id_rows = utils.read_csv(sample_id_csv)
    sample_id2clip_path = {
        row['sample_id']: row['file_path'] for row in sample_id_rows
    }
    return {
        sample_id2clip_path[row['sample_id']]: float(row[filter_key])
        for row in shapley_rows
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment