Skip to content

Instantly share code, notes, and snippets.

@mansiparashar
Created July 2, 2021 03:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mansiparashar/ee30c56e7361895d2f6b12d1597b27fc to your computer and use it in GitHub Desktop.
Save mansiparashar/ee30c56e7361895d2f6b12d1597b27fc to your computer and use it in GitHub Desktop.
def split_image_data_realwd(data, labels, n_clients=100, verbose=True):
    """Split (data, labels) among clients, each holding a random number of classes.

    Simulates a "real world" non-IID federated dataset: every client draws a
    random number of classes (at most 9, and never more than the number of
    classes present) and receives one chunk of samples from each of that many
    distinct classes.

    Notes:
        * Labels are assumed to be the integers ``0 .. n_classes - 1``; class
          membership is found with ``labels == class_``.
        * Randomness comes from both the stdlib ``random`` module and
          ``np.random``; seed both for reproducible splits.
        * If the total number of requested chunks is smaller than the number
          of classes, every class is still split into at least one chunk, so
          some samples stay unassigned.

    Args:
        data: array of samples, indexable by a list of row indices
            (``[n_data x shape]``).
        labels: array of integer labels, shape ``(n_data,)`` or ``(n_data, 1)``.
        n_clients: number of clients to create.
        verbose: if True, print the per-class sample count of every client.

    Returns:
        Object array of shape ``(n_clients, 2)``; row ``i`` holds
        ``[data_i, labels_i]`` for client ``i``.

    Raises:
        ValueError: if ``labels`` is empty, or if the greedy assignment runs
            out of classes with unassigned chunks for some client.
    """
    def break_into(n, m):
        """Return m random positive integers whose sum is n (needs n >= m)."""
        parts = [1] * m
        for _ in range(n - m):
            parts[random.randint(0, m - 1)] += 1
        return parts

    #### constants ####
    # np.unique (unlike set()) also handles column-shaped (n_data, 1) labels.
    n_classes = len(np.unique(labels))
    if n_classes == 0:
        raise ValueError("labels must contain at least one class")
    classes = list(range(n_classes))
    np.random.shuffle(classes)
    label_indcs = [list(np.where(labels == class_)[0]) for class_ in classes]

    #### number of classes for each client ####
    # A client takes one chunk per distinct class, so it can never hold more
    # classes than exist; the historical upper bound of 9 is kept otherwise.
    upper = min(10, n_classes + 1)
    tmp = [np.random.randint(1, upper) for _ in range(n_clients)]
    total_partition = sum(tmp)

    #### how many chunks every class is cut into (largest first) ####
    class_partition = sorted(break_into(total_partition, n_classes), reverse=True)

    #### cut the sample indices of every class into its chunks ####
    class_partition_split = {
        class_: [
            list(chunk)
            for chunk in np.array_split(label_indcs[ind], class_partition[ind])
        ]
        for ind, class_ in enumerate(classes)
    }

    # Explicit object array: the per-client arrays are ragged, which
    # np.array() refuses to convert on recent NumPy versions.
    clients_split = np.empty((n_clients, 2), dtype=object)
    for i, wanted in enumerate(tmp):
        indcs = []
        remaining = wanted
        # Greedy, first come first served: take one chunk from each of the
        # classes that currently have the most chunks left.
        for class_ in classes:
            if remaining == 0:
                break
            chunks = class_partition_split[class_]
            if chunks:
                indcs.extend(chunks.pop())
                remaining -= 1
        if remaining > 0:
            raise ValueError(" Unable to fulfill the criteria ")

        ##### keep classes ordered by the number of chunks they still have #####
        classes = sorted(
            classes, key=lambda c: len(class_partition_split[c]), reverse=True
        )
        clients_split[i, 0] = data[indcs]
        clients_split[i, 1] = labels[indcs]

    if verbose:
        print("Data split:")
        class_ids = np.arange(n_classes).reshape(-1, 1)
        for i, (_, client_labels) in enumerate(clients_split):
            split = np.sum(client_labels.reshape(1, -1) == class_ids, axis=1)
            print(" - Client {}: {}".format(i, split))
        print()

    return clients_split
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment