Created
July 2, 2021 03:22
-
-
Save mansiparashar/ee30c56e7361895d2f6b12d1597b27fc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def split_image_data_realwd(data, labels, n_clients=100, verbose=True):
    '''
    Split (data, labels) among n_clients so that every client holds a random
    number of classes, simulating a non-uniform "real world" dataset.

    Each client is assigned between 1 and min(9, n_classes) distinct classes.
    The examples of every class are cut into contiguous chunks and each
    client greedily takes one chunk from the classes that currently have the
    most chunks left. Chunks that no client claims are left unassigned, so
    the clients together may cover only part of the dataset.

    All randomness comes from np.random, so np.random.seed() makes the
    split reproducible.

    Input:
        data : [n_data x shape] array, indexable by an integer index array
        labels : [n_data (x 1)] numpy array with values 0 .. n_classes-1
        n_clients : number of clients
        verbose : True/False => True for printing some info, False otherwise
    Output:
        clients_split : object array of shape [n_clients x 2]; row i holds
                        (data_i, labels_i) for client i
    Raises:
        ValueError : if labels is empty, or if the greedy assignment runs
                     out of classes with unassigned chunks for some client
    '''
    def break_into(n, m):
        '''
        Return m random positive integers; they sum to n whenever n >= m
        (for n < m every entry is 1).
        '''
        extra = max(n - m, 0)
        # Drop the `extra` units uniformly at random onto the m slots.
        hits = np.random.randint(0, m, size=extra)
        return (1 + np.bincount(hits, minlength=m)).tolist()

    def print_split(clients_split):
        print("Data split:")
        for i, client in enumerate(clients_split):
            # Per-class example counts for this client.
            split = np.sum(client[1].reshape(1, -1) == np.arange(n_classes).reshape(-1, 1), axis=1)
            print(" - Client {}: {}".format(i, split))
        print()

    #### constants ####
    n_classes = np.unique(labels).size
    if n_classes == 0:
        raise ValueError("labels must contain at least one example")
    classes = list(range(n_classes))
    np.random.shuffle(classes)

    #### classes for each client ####
    # A client takes at most one chunk per class, so it can never be asked
    # for more classes than exist.
    max_per_client = min(9, n_classes)
    wanted = [np.random.randint(1, max_per_client + 1) for _ in range(n_clients)]

    #### create partition among classes to fulfill criteria for clients ####
    # Sorted descending: the class at position 0 of `classes` gets the most
    # chunks, matching the greedy order used below.
    parts_per_class = sorted(break_into(sum(wanted), n_classes), reverse=True)

    #### based on class partition, partitioning the label indexes ###
    chunks_by_class = {
        class_: np.array_split(np.where(labels == class_)[0], n_parts)
        for class_, n_parts in zip(classes, parts_per_class)
    }

    #### applying greedy approach first come and first serve ####
    clients_split = []
    for n_wanted in wanted:
        picked = []
        for class_ in classes:
            if len(picked) == n_wanted:
                break
            if chunks_by_class[class_]:
                picked.append(chunks_by_class[class_].pop())
        if len(picked) < n_wanted:
            raise ValueError(" Unable to fulfill the criteria ")

        ##### sorting classes based on the number of chunks it has left #####
        classes = sorted(classes, key=lambda c: len(chunks_by_class[c]), reverse=True)

        indcs = np.concatenate(picked)
        clients_split.append([data[indcs], labels[indcs]])

    if verbose:
        print_split(clients_split)

    # Clients hold different numbers of examples, so the result is ragged;
    # fill an object array explicitly instead of relying on np.array(), which
    # rejects ragged input on recent NumPy versions.
    packed = np.empty((len(clients_split), 2), dtype=object)
    for i, (client_data, client_labels) in enumerate(clients_split):
        packed[i, 0] = client_data
        packed[i, 1] = client_labels
    return packed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment