Skip to content

Instantly share code, notes, and snippets.

@phil8192
Created April 3, 2020 20:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save phil8192/83c254dddc0c626b3106ce78e79a7469 to your computer and use it in GitHub Desktop.
Save phil8192/83c254dddc0c626b3106ce78e79a7469 to your computer and use it in GitHub Desktop.
class FedAvg(ClassifierMixin, BaseEstimator):
    """Federated Averaging (FedAvg) classifier.

    ``fit`` builds ``n_runners`` runners from the training data via
    ``init_runners`` (presumably one data partition each). In each of
    ``rounds`` rounds, ``sample_size`` runners are drawn at random without
    replacement. Each sampled runner optimises starting from the current
    global intercept/coefficients. The returned parameters are combined with
    a weighted average into the new global parameters. The final parameters
    are turned into a predictor by ``set_weights``.

    Parameters
    ----------
    n_runners : int
        Number of runners requested from ``init_runners``.
    sample_size : int
        Runners sampled per round. ``random.sample`` raises ``ValueError``
        if this exceeds the number of runners.
    rounds : int
        Number of averaging rounds.
    combine : str
        ``'weighted'`` weights each sampled runner by
        ``runner.dataset_size() / n_samples``. Any other value weights the
        sampled runners equally.
    partition_params : dict or None
        Keyword arguments forwarded to ``init_runners``.
        ``None`` means ``{'scheme': 'uniform'}``.
    runner_hyperparams : dict or None
        Passed to every ``runner.optimise`` call.
        ``None`` means ``{'epochs': 1, 'lr': 0.15, 'batch_size': 0}``.
    intercept_init, coef_init : array-like or None
        Optional starting parameters. Whichever is missing is initialised
        on the first ``fit`` to zeros of shape ``(1,)`` / ``(1, n_features)``.
    """

    # Defaults used when the matching constructor argument is None. They live
    # here rather than in the signature so that instances never share one
    # mutable dict, and so that scikit-learn's estimator checks accept the
    # defaults. Always copied before being handed out.
    _DEFAULT_PARTITION_PARAMS = {'scheme': 'uniform'}
    _DEFAULT_RUNNER_HYPERPARAMS = {'epochs': 1, 'lr': 0.15, 'batch_size': 0}

    def __init__(self,
                 n_runners=1,
                 sample_size=1,
                 rounds=1,
                 combine='weighted',
                 partition_params=None,
                 runner_hyperparams=None,
                 intercept_init=None,
                 coef_init=None):
        self.n_runners = n_runners
        self.sample_size = sample_size
        self.rounds = rounds
        self.combine = combine
        self.partition_params = partition_params
        self.runner_hyperparams = runner_hyperparams
        # Every constructor argument is stored under its own name so that
        # BaseEstimator.get_params(), repr() and clone() can find it.
        self.intercept_init = intercept_init
        self.coef_init = coef_init
        # Current global parameters. fit() starts from these, so a second
        # fit() continues from the previous solution.
        self.intercept_ = intercept_init
        self.coef_ = coef_init
        # Runner models returned during the most recent round.
        self.models = []

    def _partition_kwargs(self):
        """Return the keyword arguments to forward to ``init_runners``."""
        if self.partition_params is None:
            return dict(self._DEFAULT_PARTITION_PARAMS)
        return self.partition_params

    def _hyperparams(self):
        """Return the hyperparameters handed to ``runner.optimise``."""
        if self.runner_hyperparams is None:
            return dict(self._DEFAULT_RUNNER_HYPERPARAMS)
        return self.runner_hyperparams

    def _collect_models(self, runners, N, hyperparams=None):
        """Run one round of local optimisation on a random subset of runners.

        Samples ``self.sample_size`` runners without replacement. Each one
        optimises starting from the current global parameters, and the
        returned models are recorded in ``self.models``.

        Parameters
        ----------
        runners : sequence
            Runners produced by ``init_runners``.
        N : int
            Total number of training rows; used to normalise weights.
        hyperparams : dict or None
            Hyperparameters for ``runner.optimise``. ``None`` means resolve
            them from ``runner_hyperparams``.

        Returns
        -------
        tuple of three lists
            Per-runner intercepts, per-runner coefficients, and per-runner
            averaging weights. A weight is ``runner.dataset_size() / N`` when
            ``combine == 'weighted'``, otherwise ``1 / sample_size``.
        """
        if hyperparams is None:
            hyperparams = self._hyperparams()
        weighted = self.combine == 'weighted'
        r_intercepts, r_coefs, r_weights = [], [], []
        self.models = []
        for runner in random.sample(runners, k=self.sample_size):
            r_model = runner.optimise(self.intercept_, self.coef_, hyperparams)
            self.models.append(r_model)
            r_intercepts.append(r_model.intercept_)
            r_coefs.append(r_model.coef_)
            r_weights.append(runner.dataset_size() / N if weighted
                             else 1 / self.sample_size)
        return r_intercepts, r_coefs, r_weights

    # FedAvg algo.
    def fit(self, X, y):
        """Train the global model with Federated Averaging.

        Parameters
        ----------
        X : array-like with ``X.shape == (n_samples, n_features)``
            Training features; distributed to the runners.
        y : array-like of length ``n_samples``
            Training labels; its unique values are passed to ``set_weights``.

        Returns
        -------
        self
        """
        # Initialise each missing parameter independently, so that a
        # user-supplied coef_init (or intercept_init) is not discarded.
        if self.intercept_ is None:
            self.intercept_ = np.zeros(1)
        if self.coef_ is None:
            self.coef_ = np.zeros((1, X.shape[1]))
        N = X.shape[0]
        hyperparams = self._hyperparams()
        # Distribute the data given to fit() (not any module-level data)
        # across the runners.
        runners = init_runners(X, y, self.n_runners, **self._partition_kwargs())
        for _ in range(self.rounds):
            r_intercepts, r_coefs, r_weights = self._collect_models(
                runners, N, hyperparams)
            # np.average divides by sum(weights), so the weights need not
            # sum to 1 (e.g. when only a subset of runners is sampled).
            self.intercept_ = np.average(r_intercepts, axis=0, weights=r_weights)
            self.coef_ = np.average(r_coefs, axis=0, weights=r_weights)
        self.global_model = set_weights(self.intercept_, self.coef_, np.unique(y))
        return self

    def predict(self, X):
        """Predict labels for ``X`` with the averaged global model.

        Raises
        ------
        RuntimeError
            If ``fit`` has not been called yet.
        """
        if not hasattr(self, 'global_model'):
            raise RuntimeError("model not trained")
        return self.global_model.predict(X)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment