An introduction to federated learning and its evolution, federated learning algorithms, the training process, and popular frameworks.

What is Federated Learning?

Federated Learning is an innovative approach that prioritizes privacy and security in the context of distributed machine learning. It allows multiple entities, such as devices or organizations, to collaboratively train a global machine learning model while keeping their raw data stored locally, never shared or transmitted. This characteristic makes federated learning particularly appealing in scenarios where data privacy and security are of paramount importance. A common FL setup is shown in the figure below.

[Figure: a typical federated learning setup]

One of the key advantages of federated learning is that it allows organizations or devices to leverage their local data and expertise without the need to share sensitive information. This is particularly beneficial in scenarios where data ownership, regulatory compliance, or confidentiality concerns are critical. Federated learning enables entities to collaborate on model training while maintaining control over their data, mitigating the risks associated with data breaches or unauthorized access.

The Evolution and the Cause:

Machine Learning (ML) has revolutionized various industries, with applications ranging from image analysis and sequence modeling to natural language processing (NLP) and autonomous systems. However, one of the significant challenges in traditional ML is generalizability: domain shifts often occur between industry-specific data and the general-purpose public datasets used for training, and overcoming them requires large amounts of heterogeneous data. To meet this need, distributed computing systems emerged, in which algorithms are executed on multiple nodes using data aggregated from various sources onto a centralized server. This approach enables parallel training and reduces the overall time required for machine learning tasks.

However, the centralization of data aggregated from different sources in distributed computing systems raises concerns about data privacy. Data leakage can have severe consequences, making data privacy a primary concern in various industries. To address this challenge, Federated Learning (FL) has been proposed as a decentralized machine-learning approach that involves individual parties training a shared model with their private data and only exchanging weight updates with a central server.

The Process:

A typical Federated Learning setup adds two main components to a traditional training pipeline:

  • Collaborator: A collaborator is a client that holds the local training, validation, and test datasets. The federated learning setup is designed so that the collaborator is the only component with access to this data; the local dataset never leaves the collaborator's environment.

  • Aggregator: An aggregator is a component, often combined with a parameter server, that receives the locally tuned models from the collaborators and combines them into a new global model. The aggregation uses an algorithm designed or chosen to suit the problem and the data distribution.

In addition, a federation round is a defined interval, typically measured in training steps, at the end of which an aggregation is performed. Collaborators may train the model for multiple epochs, or even partial epochs, within a single round, allowing flexibility in the training process. A minimal sketch of the two components follows.
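The sketch below illustrates the two roles in plain Python with NumPy; the class and method names are illustrative assumptions, not the API of any particular FL framework, and the local update is a single gradient step on a linear least-squares model.

```python
import numpy as np

class Collaborator:
    """Holds private data; the raw dataset never leaves this object."""
    def __init__(self, X, y):
        self.X, self.y = X, y

    def train_locally(self, global_weights, lr=0.1):
        # Placeholder local training: one gradient step on a linear
        # least-squares model. A real collaborator would run its own
        # multi-epoch training loop here.
        w = global_weights.copy()
        grad = self.X.T @ (self.X @ w - self.y) / len(self.y)
        return w - lr * grad, len(self.y)

class Aggregator:
    """Combines locally tuned models into a new global model."""
    def aggregate(self, updates):
        # Weighted average of client weights, proportional to each
        # collaborator's dataset size (FedAvg-style aggregation).
        total = sum(n for _, n in updates)
        return sum(w * (n / total) for w, n in updates)
```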

During the FL process, each client trains its local model using its own dataset, without sharing the raw data with the global server or other clients. The clients update the weights of their local model based on their local dataset and transmit these updated weights back to the global server. The global server then aggregates the received weights from all the clients to update the global model. This aggregation process typically involves techniques such as weighted averaging or other aggregation strategies to combine the model updates from different clients.

The process of Federated Learning can be organized into the following steps:

  1. Model initialization: The global model is customized to the specific machine learning problem and initialized on the global server, using random or pre-existing weights.
  2. Client model initialization: A copy of the initialized model is dispatched to the clients for local training.
  3. Local training: Each client trains its local copy of the model on its private dataset.
  4. Weight update: After local training, each client transmits its updated weights back to the global server.
  5. Global model update: The global server aggregates the weights received from the clients and updates the global model.
  6. Redistribute: The updated global model is redistributed to each client for the next round of training.
  7. Repeat: Steps 3 to 6 constitute one round of Federated Learning; they are repeated until the desired convergence is achieved. A minimal simulation of these steps is sketched below.
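The following self-contained sketch simulates the seven steps in a single process on synthetic data; local_train, the linear model, and all hyperparameters are illustrative assumptions rather than a prescribed API.

```python
import numpy as np

rng = np.random.default_rng(0)

def local_train(weights, X, y, lr=0.1, epochs=5):
    # Step 3: each client trains its local copy on private data.
    w = weights.copy()
    for _ in range(epochs):
        w -= lr * (X.T @ (X @ w - y) / len(y))
    return w

# Three clients with synthetic private datasets (never pooled).
clients = [(rng.normal(size=(50, 4)), rng.normal(size=50)) for _ in range(3)]

global_w = np.zeros(4)                     # Steps 1-2: initialize and dispatch
for round_idx in range(10):                # Step 7: repeat rounds
    # Steps 3-4: local training, then updated weights sent to the server.
    updates = [(local_train(global_w, X, y), len(y)) for X, y in clients]
    # Step 5: weighted aggregation; Step 6: redistribution next iteration.
    n_total = sum(n for _, n in updates)
    global_w = sum(w * (n / n_total) for w, n in updates)
```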

The Federation Algorithm:

The aggregation of weights by the global server is crucial for combining the model updates from the respective collaborators. Several standard aggregation algorithms have been proposed, such as FedAvg, FedProx, and FedOpt. FedAvg, short for Federated Averaging, computes a weighted average of the updates received from the collaborators after local training, with each collaborator weighted by the size of its dataset. This ensures that each collaborator's contribution is proportional to the amount of data it holds.

Algorithm: FedAvg

Input: Client datasets D1, D2, ..., DK. Output: Global model M_R.

  1. Initialize global model M_0
  2. For round r = 0 to R-1 do:
    1. Send M_r to all clients
    2. For each client k = 1 to K do:
      1. Receive M_r from the server
      2. Sample a mini-batch B_k from D_k
      3. Compute the local gradient ∇f_k(M_r, B_k)
      4. Send ∇f_k(M_r, B_k) to the server
    3. Aggregate the client gradients: ∇F_r ← Σ(k=1 to K) (n_k / n) ∇f_k(M_r, B_k), where n_k is the size of D_k and n = Σ n_k
    4. Update the global model: M_{r+1} ← M_r − η ∇F_r
  3. Return M_R
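As written, each client sends a single mini-batch gradient per round, which is the FedSGD special case of FedAvg; with multiple local steps per round, clients send updated weights instead. Below is a direct transcription of the pseudocode on synthetic data; the linear least-squares model, learning rate, and batch size are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
K, d, eta, R = 3, 4, 0.1, 20
datasets = [(rng.normal(size=(60, d)), rng.normal(size=60)) for _ in range(K)]
n = sum(len(y_k) for _, y_k in datasets)

M = np.zeros(d)                                  # 1. initialize M_0
for r in range(R):                               # 2. rounds r = 0..R-1
    grads = []
    for X_k, y_k in datasets:                    # 2.2 each client k
        idx = rng.choice(len(y_k), size=16, replace=False)   # sample B_k
        B_X, B_y = X_k[idx], y_k[idx]
        g = B_X.T @ (B_X @ M - B_y) / len(B_y)   # local gradient
        grads.append((g, len(y_k)))              # sent to the server
    # 2.3 aggregate gradients weighted by n_k / n
    F_r = sum(g * (n_k / n) for g, n_k in grads)
    M = M - eta * F_r                            # 2.4 update global model
# 3. M now holds M_R
```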

These algorithms have been designed to address specific challenges and limitations of federated learning, such as data heterogeneity, communication constraints, and non-i.i.d. data distributions. Additionally, there are many other federated learning algorithms that have been developed to adapt to different real-world data distributions and conditions, highlighting the ongoing research and advancements in this field.
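For instance, FedProx tempers client drift on heterogeneous data by adding a proximal term (mu/2)·||w − w_global||² to each client's local objective. A minimal sketch of the resulting local update, assuming a linear model with squared error and treating mu as a tunable constant:

```python
import numpy as np

def fedprox_local_step(w, w_global, X, y, lr=0.1, mu=0.01):
    # Gradient of the local loss plus mu * (w - w_global), the gradient
    # of the proximal term, which penalizes drifting far from the
    # current global model on non-i.i.d. data.
    grad = X.T @ (X @ w - y) / len(y) + mu * (w - w_global)
    return w - lr * grad
```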

Types of Federated Learning:

Federated learning can be classified into three types based on how the data is partitioned across the collaborating parties: Horizontal (the original form), Vertical, and Transfer federated learning.

  • Horizontal Federated Learning: In horizontal federated learning, the collaborating parties hold datasets with the same feature space but different samples. For example, multiple hospitals might collaborate to build a model for predicting disease outbreaks: each hospital records the same features, such as patient age, gender, and medical history, but for different patients. (This partitioning is illustrated in the sketch after this list.)

  • Vertical Federated Learning: In vertical federated learning, the collaborating parties hold different feature sets, and the goal is to build a model that makes predictions using all of the available features. For example, a bank and an e-commerce platform might collaborate on fraud detection: the bank holds financial transaction data, the platform holds user behavior data, and the model draws on both.

  • Transfer Federated Learning: This variant incorporates transfer learning, a technique for cross-domain knowledge transfer, into FL. The collaborating parties hold different datasets, related to different features or samples, with only partial overlap. For example, a company specializing in image data and another specializing in text data might collaborate on an image-captioning model, each contributing some of its data so that the model can generate captions using both image and text data.
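The contrast between the horizontal and vertical settings can be pictured as slicing one hypothetical data table along different axes; the toy array below is purely illustrative.

```python
import numpy as np

data = np.arange(24).reshape(6, 4)   # rows = samples, columns = features

# Horizontal FL: same feature space, disjoint samples per party.
hospital_a, hospital_b = data[:3, :], data[3:, :]   # shapes (3, 4), (3, 4)

# Vertical FL: same samples, disjoint feature sets per party.
bank, ecommerce = data[:, :2], data[:, 2:]          # shapes (6, 2), (6, 2)
```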

Furthermore, Federated Learning (FL) can be classified into two types based on the implementation setting: cross-device and cross-silo.

  • Cross-device Federated Learning: Cross-device federated learning trains machine learning models across many devices, such as smartphones, laptops, and tablets, while preserving the privacy of the data on each device. This leverages a large amount of data for training without compromising individual users' privacy, as demonstrated by Google's Gboard keyboard and Apple's QuickType keyboard with the vocal classifier for "Hey Siri".

  • Cross-Silo Federated Learning: Cross-silo federated learning trains machine learning models across a small number of data silos, such as different organizations or departments, without sharing the raw data. This allows institutions to harness data from across silos to build more accurate models and gain better insights, as seen in applications such as financial risk prediction for reinsurance, medical image segmentation, and pharmaceutical discovery.

Federated Learning Frameworks:

In federated learning, numerous frameworks with unique features and capabilities are available, some of which are tailored exclusively to specific organizations. The Joint Imaging Platform (JIP) and MELLODDY are examples of such frameworks. JIP is a strategic initiative within the German Cancer Consortium (DKTK) that seeks to establish a technical infrastructure for modern, distributed imaging research, focusing on machine learning methods for medical image processing to facilitate collaborations between clinical sites and support multicenter trials. MELLODDY, short for Machine Learning Ledger Orchestration for Drug Discovery, enhances predictive models and improves drug discovery efficiency by harnessing the world's largest collection of small molecules: the platform's ten pharmaceutical partners contribute over a billion highly private and competitive data points, including hundreds of terabytes of image data. Some other flexible FL frameworks include:

  • TensorFlow Federated (TFF): Developed by Google, TFF is an open-source framework that is primarily used for simulation, where a single machine simulates clients and servers. It provides APIs for defining federated computations and is widely used in FL research and development.

  • PySyft: Developed by the OpenMined community, PySyft is an open-source framework that supports secure multi-party computation (MPC) and is designed for FL. It provides APIs for defining federated computations and supports differential privacy and homomorphic encryption, making it suitable for privacy-sensitive FL scenarios.

  • IBM Federated Learning: Developed by IBM, this framework provides APIs for defining federated computations and supports differential privacy and secure aggregation. It is designed for enterprise-level FL applications and is used by organizations for collaborative machine learning across distributed datasets.

  • NVIDIA Clara: Initially developed for medical image segmentation, NVIDIA Clara helps healthcare institutions train machine learning models collaboratively without exposing patient data to unauthorized parties. It is specifically tailored for healthcare applications and provides secure, privacy-preserving FL capabilities.

  • Flower: Flower was developed by the machine learning team at Adap, a research lab based in Berlin, Germany. Flower is scalable, allowing researchers to apply it to various use cases, including healthcare, finance, and smart cities. It provides flexibility in defining federated computations and is known for its ease of use (a minimal client skeleton appears after this list).

  • OpenFL: OpenFL is Intel's federated learning platform, which grew out of a collaboration between Intel and the University of Pennsylvania to develop the Federated Tumor Segmentation (FeTS) platform, a real-world medical federated learning platform with international collaborators. OpenFL is structured so that researchers and data scientists can test algorithms and modify FL tasks independently, without needing deep knowledge of the framework internals.
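To give a flavor of these APIs, here is a minimal Flower client skeleton following the NumPyClient pattern from Flower's tutorials; exact APIs vary across Flower versions, and the model and data handling are left as placeholders.

```python
import flwr as fl
import numpy as np

class ToyClient(fl.client.NumPyClient):
    def __init__(self):
        self.weights = [np.zeros(4)]      # placeholder model parameters

    def get_parameters(self, config):
        return self.weights

    def fit(self, parameters, config):
        # Real local training on private data would go here; this
        # placeholder simply echoes the received parameters back.
        self.weights = parameters
        return self.weights, 1, {}        # parameters, num_examples, metrics

    def evaluate(self, parameters, config):
        return 0.0, 1, {}                 # loss, num_examples, metrics

# Each collaborator would connect to a running Flower server, e.g.:
# fl.client.start_numpy_client(server_address="127.0.0.1:8080",
#                              client=ToyClient())
```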
