Skip to content

Instantly share code, notes, and snippets.

@chilang
Created May 25, 2020 17:06
Show Gist options
  • Save chilang/d844b5205f41ee63ec94f41c621d4762 to your computer and use it in GitHub Desktop.
Save chilang/d844b5205f41ee63ec94f41c621d4762 to your computer and use it in GitHub Desktop.
def compare_f1(hyperparams, rounds=1):
    """Compare weighted F1 of a FedAvg ensemble against its individual models.

    Each round builds a fresh ``FedAvg(**hyperparams)``, fits it on the
    module-level ``X_train`` / ``y_train`` and scores it on ``X_test`` /
    ``y_test``. Every entry of ``fed_avg.models`` is then scored on the same
    test data, and the mean and max of those per-model scores are recorded.

    Args:
        hyperparams: Mapping unpacked as keyword arguments into ``FedAvg``.
        rounds: Number of independent fit/evaluate repetitions.

    Returns:
        A 3-tuple, each element averaged over all rounds:
        ``(mean per-model F1, best per-model F1, FedAvg F1)``.

    NOTE(review): ``FedAvg``, ``f1_score`` and the train/test arrays are
    resolved from the enclosing module; ``f1_score`` is presumably
    ``sklearn.metrics.f1_score`` -- confirm against the file's imports.
    """
    mean_of_locals = []
    best_of_locals = []
    aggregated = []

    for _ in range(rounds):
        # A new estimator per round so repetitions do not share fitted state.
        estimator = FedAvg(**hyperparams)
        estimator.fit(X_train, y_train)

        # Score the aggregated model first, then each of its member models.
        aggregated.append(
            f1_score(y_test, estimator.predict(X_test), average='weighted')
        )
        per_model = [
            f1_score(y_test, member.predict(X_test), average='weighted')
            for member in estimator.models
        ]

        mean_of_locals.append(np.mean(per_model))
        best_of_locals.append(np.max(per_model))

    return (
        np.mean(mean_of_locals),
        np.mean(best_of_locals),
        np.mean(aggregated),
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment