<h1 id="iecbes-2022-conference-highlights">IECBES 2022 - Conference Highlights</h1>
<p>👋 Hi there. Welcome back to my page. This week, I had the opportunity to travel to Malaysia for the first time to attend in person and present at a reputed scientific conference, <a href="https://www.iecbes.org/">IECBES 2022</a>. It was a memorable event in my personal journey. Today, I will briefly summarize the conference speeches and presentations in this highlights blog post.</p>
<h2 id="1-about-iecbes-2022">1. About IECBES 2022</h2>
<p>IECBES stands for the IEEE-EMBS Conference on Biomedical Engineering and Sciences, which is organized once every 2 years by the IEEE Engineering in Medicine and Biology Society (IEEE-EMBS) Malaysia Chapter. The 7th IECBES, with the major theme of “Healthcare Personalisation: For the Future & Beyond”, was held in Kuala Lumpur from December 7th to 9th, 2022.</p>
<p>Consistent with this theme, the conference featured 6 keynote lectures and 7 invited speeches from leading academic scientists, along with 75 accepted papers categorized into 11 tracks and one special session. Figure 1 illustrates the number of papers in each track.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/iecbes/numbers.jpg" />
<figcaption>Figure 1. Number of Accepted Papers in Each Track. </figcaption>
</figure>
<p><strong>Insight 1</strong>: As we can observe from Figure 1, Biomedical Signal Processing is the most popular track, followed by Internet of Things in Biomedical Engineering and Healthcare, Biomedical Imaging and Image Processing, and Biomedical Instrumentation and Devices. Perhaps because the pandemic has subsided, the Pre & Post COVID-19 Pandemic Response track attracted less interest, with only 2 accepted papers.</p>
<p>Next, let’s use a word cloud visualization of the titles of all accepted papers to get a sense of what these works are about. In Figure 2, the size of each word indicates its frequency.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/iecbes/word-cloud.png" />
<figcaption>Figure 2. Word Cloud of the Titles of All Accepted Papers. </figcaption>
</figure>
<p><strong>Insight 2</strong>: From the above word cloud visualization, we can make some observations:</p>
<ul>
<li>Classification, detection (which is also a form of classification), and analysis are the most frequently conducted tasks.</li>
<li>The large size of words such as “deep learning” and “neural network” shows the broad application of AI.</li>
<li>Images and signals are the most frequent data modalities used.</li>
<li>The small size of the word “device” indicates that not many works focus on hardware.</li>
</ul>
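<p>As an aside, a word cloud like the one in Figure 2 can be generated in a few lines of Python with the <code class="language-plaintext highlighter-rouge">wordcloud</code> package. Below is a minimal sketch; the <code class="language-plaintext highlighter-rouge">titles</code> list is a hypothetical placeholder for the real list of accepted paper titles.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Minimal sketch: build a word cloud from paper titles.
# "titles" is a placeholder; fill it with the real accepted paper titles.
from wordcloud import WordCloud

titles = [
    "Deep Learning for 3-lead ECG Classification",
    "Image-To-Graph Transformation via Superpixel Clustering",
]
text = " ".join(titles)

# Word size is proportional to word frequency, as in Figure 2.
cloud = WordCloud(width = 1200, height = 600, background_color = "white").generate(text)
cloud.to_file("word-cloud.png")
</code></pre></div></div>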
<h2 id="2-outstanding-papers">2. Outstanding Papers</h2>
<p>In this section, I will briefly introduce some outstanding papers that I found noteworthy. These papers are also in tracks I am interested in, such as Biomedical Imaging and Image Processing, Biomedical Signal Processing, Cardiovascular and Respiratory Systems, Pre & Post COVID-19 Pandemic Response, and the Special Session - Trends in Smarter Healthcare: AI for Images. Some of these papers were recognized among the Top 10 best papers.</p>
<p><strong>Reconstruction of Fetal Head Surface from A Few 2D Ultrasound Images Tracked in 3D Space</strong><br />
by <em>Sandra Marcadent; Johann Hêches; Julien Favre; David Desseauve; Jean-Philippe Thiran</em><br />
in <em>Biomedical Imaging and Image Processing</em><br /></p>
<blockquote>
<p>In this pilot study, we present a novel method to reconstruct the fetal head surface, from a small set of tracked 2D ultrasound images around the transthalamic brain plane. Indeed, 3D visualization of the fetus’s prominent skull at the beginning of birth could help the obstetrician in decision-making to overcome dystocia, a delivery complication that results in labor obstruction. The use of 2D ultrasound images tracked in 3D would allow superimposing of the fetal head model to other reconstructed organs. However, fetal motion may affect the consistency of ultrasound images, in particular, if many frames are needed. Moreover, the fetal head is large at late pregnancy stages which causes occlusions in the ultrasound images. We thus propose and compare the performance of two different methods to reconstruct a fetal head surface from only 10 focused frames. Our best method achieves 1.6 mm of average reconstruction error in simulation based on an MRI dataset of 7 patients at 34-36 weeks of pregnancy.</p>
</blockquote>
<p><strong>Vector-Quantized Zero-Delay Deep Autoencoders for The Compression of Electrical Stimulation Patterns of Cochlear Implants Using STOI</strong><br />
by <em>Reemt Hinrichs; Felix Ortmann; Jörn Ostermann</em><br />
in <em>Biomedical Signal Processing</em><br /></p>
<blockquote>
<p>Cochlear implants (CIs) are battery-powered, surgically implanted hearing aids capable of restoring a sense of hearing in people suffering from moderate to profound hearing loss. Wireless transmission of audio from or to signal processors of CIs can be used to improve speech understanding and localization of CI users. Data compression algorithms can be used to conserve battery power in this wireless transmission. However, very low latency is a strict requirement, limiting severely the available source coding algorithms. Previously, instead of coding the audio, coding the electrical stimulation patterns of CIs was proposed to optimize the trade-off between bit rate, latency, and quality. In this work, a zero-delay deep autoencoder (DAE) for the coding of the electrical stimulation patterns of CIs is proposed. Combining for the first time Bayesian optimization with numerically approximated gradients of a non-differentiable speech intelligibility measure for CIs, the short-time intelligibility measure (STOI), an optimized DAE architecture was found and trained that achieved equal or superior speech understanding at zero delays, outperforming well-known audio codecs. The DAE achieved reference vocoder STOI scores at 13.5 kbit/s compared to 33.6 kbit/s for Opus and 24.5 kbit/s for AMR-WB.</p>
</blockquote>
<p><strong>Performance of A Wireless Electrocardiogram System Based on Wi-Fi and BLE Technology</strong><br />
by <em>Nusrat Hassan Khan; S M Nafiul Hasan Joy; Fauzan Khairi Che Harun; Weng Howe Chan; Nurul Ashikin Abdul-Kadir; Moey Keith</em><br />
in <em>Biomedical Instrumentation and Devices</em><br /></p>
<blockquote>
<p>Wearable electrocardiogram (ECG) systems have increasingly been used in everyday life, breaking down the barriers that formerly existed only within hospitals. They allow for non-invasive continuous monitoring of a variety of heart parameters. The aim of this work is to investigate and assess the development of a user-friendly, mobile, and compact wearable ECG system for instantaneous recording. The work also presented the design of the ECG system with Autodesk EAGLE and Fusion 360 which has wireless connectivity via Bluetooth and Wi-Fi. The functionality of this ECG system is aided by the BMD101 cardio chip device, which is composed of an amplifier, filter, and 16-bit analog-to-digital converter. The results indicated a regular cardiac rhythm of 60 beats per minute (bpm), 120 bpm, and 180 bpm, respectively, along with the abnormal heart condition of ventricular tachycardia. Eventually, this study concluded with a list of key remaining obstacles as well as the potential for development in terms of result display and system software, both of which are vital for continued advancement.</p>
</blockquote>
<p><strong>Enhancing Deep Learning-based 3-lead ECG Classification with Heartbeat Counting and Demographic Data Integration</strong><br />
by <em>Khiem Le; Huy Hieu Pham; Thao Nguyen; Tu Nguyen; Cuong Do; Tien Ngoc Thanh</em><br />
in <em>Cardiovascular and Respiratory Systems</em><br /></p>
<blockquote>
<p>Nowadays, an increasing number of people are being diagnosed with cardiovascular diseases (CVDs), the leading cause of death globally. The gold standard for identifying these heart problems is via electrocardiogram (ECG). The standard 12-lead ECG is widely used in clinical practice and the majority of current research. However, using a lower number of leads can make ECG more pervasive as it can be integrated with portable or wearable devices. This article introduces two novel techniques to improve the performance of the current deep learning system for 3-lead ECG classification, making it comparable with models that are trained using standard 12-lead ECG. Specifically, we propose a multi-task learning scheme in the form of the number of heartbeats regression and an effective mechanism to integrate patient demographic data into the system. With these two advancements, we got classification performance in terms of F1 scores of 0.9796 and 0.8140 on two large-scale ECG datasets, i.e., Chapman and CPSC-2018, respectively, which surpassed current state-of-the-art ECG classification methods, even those trained on 12-lead data.</p>
</blockquote>
<p><strong>Image-To-Graph Transformation via Superpixel Clustering to Build Nodes in Deep Learning for Graph</strong><br />
by <em>Hong Seng Gan; Muhammad Hanif Ramlee; Asnida Abdul Wahab; Wan Mahani Hafizah Wan Mahmud; De Rosal Ignatius Moses Setiadi</em><br />
in <em>Special Session - Trends in Smarter Healthcare: AI for Images</em><br /></p>
<blockquote>
<p>In recent years, convolutional neural networks (CNN) becomes the mainstream image-processing technique for numerous medical imaging tasks such as segmentation, classification, and detection. Nonetheless, CNN is limited to processing fixed-size input and demonstrates low generalizability to unseen features. Graph deep learning adopts graph concepts and properties to capture rich information from complex data structures. Graphs can effectively analyze the pairwise relationship between the target entities. Implementation of graph deep learning in medical imaging requires the conversion of grid-like image structure into a graph representation. To date, the conversion mechanism remains underexplored. In this work, image-to-graph conversion via clustering has been proposed. Locally grouped homogeneous pixels have been grouped into a superpixel, which can be identified as a node. Simple linear iterative clustering (SLIC) emerged as the suitable clustering technique to build superpixels as nodes for subsequent graph deep learning computation. The method was validated on the knee, cell, and membrane image datasets. SLIC has reported a Rand score of 0.92±0.015 and a Silhouette coefficient of 0.85±0.02 for the cell dataset, 0.62±0.02 (Rand score) and 0.61±0.07 (Silhouette coefficient) for the membrane dataset, and 0.82±0.025 (Rand score) and 0.67±0.02 (Silhouette coefficient) for knee dataset. Future works will investigate the performance of superpixel with enforcing connectivity as the prerequisite to develop graph deep learning for medical image segmentation.</p>
</blockquote>
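<p>The image-to-graph conversion described in this abstract can be prototyped with off-the-shelf tools. The snippet below is only an illustrative sketch of the general idea using scikit-image’s SLIC, not the authors’ implementation: each superpixel becomes a node, and its mean color serves as a simple node feature.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Illustrative sketch (not the paper's code): build graph nodes from
# SLIC superpixels, using each superpixel's mean color as its feature.
import numpy as np
from skimage import data
from skimage.segmentation import slic

image = data.astronaut()  # any RGB image

# Assign every pixel to a superpixel; each superpixel becomes a node.
segments = slic(image, n_segments = 100, compactness = 10)

node_features = np.stack([
    image[segments == sid].mean(axis = 0) for sid in np.unique(segments)
])
print(node_features.shape)  # (num_superpixels, 3)
</code></pre></div></div>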
<h2 id="3-closing">3. Closing</h2>
<p>IECBES is a rising conference that is becoming more and more attractive in the field of Biomedical Engineering and Sciences. However, the conference is still fledgling, with an h-index of around 7; as a result, not all of its accepted papers are of high quality. I hope the conference will continue to grow, attracting more outstanding research from all over the world in the coming years.</p>
<p>Stay tuned for more content …</p>
<h1 id="end-to-end-named-entity-recognition">End-to-End Named Entity Recognition</h1>
<p>👋 Hi there. Welcome back to my page. In the last half-decade, Natural Language Processing (NLP) applications have appeared more and more in industrial products and business processes, reaching the same popularity as Computer Vision. Therefore, this field is too big to ignore. In this post, we will talk for the first time about an important NLP problem, <a href="https://en.wikipedia.org/wiki/Named-entity_recognition">Named Entity Recognition</a> (NER), which is the task of tagging entities in text with their corresponding types. In addition, we will also build a simple AI-core web application using <a href="https://gradio.app/">Gradio</a> and Hugging Face <a href="https://huggingface.co/spaces">Spaces</a>.</p>
<h2 id="1-about-ner-and-nlp">1. About NER and NLP</h2>
<p>Where is NER in the big picture of NLP? NLP is a large field with many diverse tasks. However, nearly every NLP task can be categorized into one of 4 main groups:</p>
<ul>
<li><strong>Text Classification</strong>: Similar to Image Classification, Text Classification is a task of assigning a set of predefined classes to a sequence (a sentence, paragraph, or whole document). Some of the most well-known examples of Text Classification include Sentiment Analysis, Topic Labeling, Language Detection, and Intent Detection.</li>
<li><strong>Text Tagging</strong>: Text Tagging or Text Labeling is a core Information Extraction task in which each word (token) in a sequence is classified using a pre-defined label set. Text Tagging has some exciting applications such as Named Entity Recognition and Part-of-Speech Tagging; a tiny tagging example is given right after this list.</li>
<li><strong>A mix of Text Classification and Tagging</strong>: Multi-task Learning has been introduced many times on my page. This group of tasks is where a model is expected to classify a given sequence and tag every word of it simultaneously. Some examples of this group are joint Named Entity Recognition and Relation Extraction, and joint Intent Detection and Slot Filling.</li>
<li><strong>Text Generation</strong>: Text Generation is the task of generating text that aims to be indistinguishable from human-written text. This task has many wonderful applications such as (Abstractive) Document Summarization, Machine Translation, and Chatbots.</li>
</ul>
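<p>To make the Text Tagging group concrete, NER is usually framed as per-token classification with a BIO (Begin/Inside/Outside) label scheme. Below is a tiny hand-made example, using entity types from the dataset introduced in the next section:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># A hand-made BIO-tagged sentence: each token gets one label, where
# B- starts an entity, I- continues it, and O means outside any entity.
tokens = ["Patient", "416", "traveled", "to", "Da", "Nang", "on", "July", "25"]
labels = ["O", "B-PATIENT_ID", "O", "O", "B-LOCATION", "I-LOCATION", "O", "B-DATE", "I-DATE"]

for token, label in zip(tokens, labels):
    print(token, label)
</code></pre></div></div>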
<h2 id="2-phoner-covid-19-dataset">2. PhoNER-COVID-19 Dataset</h2>
<p>In this tutorial, I will use the <a href="https://arxiv.org/abs/2104.03879v1">PhoNER-COVID-19</a> dataset, a dataset for recognizing COVID-19-related named entities in Vietnamese news, consisting of 35K entities over 10K sentences. The dataset includes 10 entity types aimed at extracting key information related to COVID-19 patients, which is especially useful in downstream applications. In general, these entity types can be used in the context of not only the COVID-19 pandemic but also other future epidemics:</p>
<table>
<thead>
<tr>
<th style="text-align: left">Entity Type</th>
<th style="text-align: left">Definition</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align: left">PATIENT_ID</td>
<td style="text-align: left">Unique identifier of a COVID-19 patient in Vietnam. An PATIENT_ID annotation over “X” refers to as the X-th patient having COVID-19 in Vietnam.</td>
</tr>
<tr>
<td style="text-align: left">NAME</td>
<td style="text-align: left">Name of a patient or person who comes into contact with a patient.</td>
</tr>
<tr>
<td style="text-align: left">AGE</td>
<td style="text-align: left">Age of a patient or person who comes into contact with a patient.</td>
</tr>
<tr>
<td style="text-align: left">GENDER</td>
<td style="text-align: left">Gender of a patient or person who comes into contact with a patient.</td>
</tr>
<tr>
<td style="text-align: left">JOB</td>
<td style="text-align: left">Job of a patient or person who comes into contact with a patient.</td>
</tr>
<tr>
<td style="text-align: left">LOCATION</td>
<td style="text-align: left">Locations/places that a patient was presented at.</td>
</tr>
<tr>
<td style="text-align: left">ORGANIZATION</td>
<td style="text-align: left">Organizations related to a patient, e.g. company, government organization, and the like, with structures and their own functions.</td>
</tr>
<tr>
<td style="text-align: left">SYMPTOM_AND_DISEASE</td>
<td style="text-align: left">Symptoms that a patient experiences, and diseases that a patient had prior to COVID-19 or complications that usually appear in death reports.</td>
</tr>
<tr>
<td style="text-align: left">TRANSPORTATION</td>
<td style="text-align: left">Means of transportation that a patient used. Here, we only tag the specific identifier of vehicles, e.g. flight numbers and bus/car plates.</td>
</tr>
<tr>
<td style="text-align: left">DATE</td>
<td style="text-align: left">Any date that appears in the sentence.</td>
</tr>
</tbody>
</table>
<p>The dataset was randomly split into train/val/test sets with a ratio of 5/2/3, ensuring comparable distributions of entity types across these three sets. Statistics of the dataset are presented in the table below:</p>
<table>
<thead>
<tr>
<th style="text-align: left">Entity Type</th>
<th style="text-align: right">Train</th>
<th style="text-align: right">Val</th>
<th style="text-align: right">Test</th>
<th style="text-align: right">All</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align: left">PATIENT_ID</td>
<td style="text-align: right"> 3240</td>
<td style="text-align: right"> 1276</td>
<td style="text-align: right"> 2005</td>
<td style="text-align: right"> 6521</td>
</tr>
<tr>
<td style="text-align: left">NAME</td>
<td style="text-align: right"> 349</td>
<td style="text-align: right"> 188</td>
<td style="text-align: right"> 318</td>
<td style="text-align: right"> 855</td>
</tr>
<tr>
<td style="text-align: left">AGE</td>
<td style="text-align: right"> 682</td>
<td style="text-align: right"> 361</td>
<td style="text-align: right"> 582</td>
<td style="text-align: right"> 1625</td>
</tr>
<tr>
<td style="text-align: left">GENDER</td>
<td style="text-align: right"> 542</td>
<td style="text-align: right"> 277</td>
<td style="text-align: right"> 462</td>
<td style="text-align: right"> 1281</td>
</tr>
<tr>
<td style="text-align: left">JOB</td>
<td style="text-align: right"> 205</td>
<td style="text-align: right"> 132</td>
<td style="text-align: right"> 173</td>
<td style="text-align: right"> 510</td>
</tr>
<tr>
<td style="text-align: left">LOCATION</td>
<td style="text-align: right"> 5398</td>
<td style="text-align: right"> 2737</td>
<td style="text-align: right"> 4441</td>
<td style="text-align: right"> 12576</td>
</tr>
<tr>
<td style="text-align: left">ORGANIZATION</td>
<td style="text-align: right"> 1137</td>
<td style="text-align: right"> 551</td>
<td style="text-align: right"> 771</td>
<td style="text-align: right"> 2459</td>
</tr>
<tr>
<td style="text-align: left">SYMPTOM_AND_DISEASE</td>
<td style="text-align: right"> 1439</td>
<td style="text-align: right"> 766</td>
<td style="text-align: right"> 1136</td>
<td style="text-align: right"> 3341</td>
</tr>
<tr>
<td style="text-align: left">TRANSPORTATION</td>
<td style="text-align: right"> 226</td>
<td style="text-align: right"> 87</td>
<td style="text-align: right"> 193</td>
<td style="text-align: right"> 506</td>
</tr>
<tr>
<td style="text-align: left">DATE</td>
<td style="text-align: right"> 2549</td>
<td style="text-align: right"> 1103</td>
<td style="text-align: right"> 1654</td>
<td style="text-align: right"> 5306</td>
</tr>
<tr>
<td style="text-align: left"># Entities in total</td>
<td style="text-align: right"> 15767</td>
<td style="text-align: right"> 7478</td>
<td style="text-align: right"> 11735</td>
<td style="text-align: right"> 34984</td>
</tr>
<tr>
<td style="text-align: left"># Sentences in total</td>
<td style="text-align: right"> 5027</td>
<td style="text-align: right"> 2000</td>
<td style="text-align: right"> 3000</td>
<td style="text-align: right"> 10027</td>
</tr>
</tbody>
</table>
<p>If you are building your own dataset, <a href="https://prodi.gy/">Prodigy</a> is a great annotation tool for NER tasks.</p>
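<p>For completeness, here is a minimal sketch for reading word-per-line NER data, assuming the common CoNLL-style layout (one token and its tag per line, sentences separated by blank lines). The file name is a hypothetical placeholder; adapt the path and column order to your copy of the data.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Minimal sketch: read CoNLL-style NER files (one "token tag" pair per
# line, sentences separated by blank lines). The file name is a placeholder.
def read_conll(path):
    sentences, tokens, tags = [], [], []
    with open(path, encoding = "utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # a blank line ends the current sentence
                if tokens:
                    sentences.append((tokens, tags))
                    tokens, tags = [], []
            else:
                token, tag = line.split()[:2]
                tokens.append(token)
                tags.append(tag)
    if tokens:  # flush the last sentence if the file lacks a trailing blank line
        sentences.append((tokens, tags))
    return sentences

train_sentences = read_conll("train_word.conll")
print(len(train_sentences))
</code></pre></div></div>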
<h2 id="references">References</h2>
<p style="font-size: 14px;"><a href="https://arxiv.org/abs/2104.03879v1">[1] COVID-19 Named Entity Recognition for Vietnamese</a><br /></p>Gathering.AI👋 Hi there. Welcome back to my page. In the last half-decade, Natural Language Processing (NLP) applications appear more and more in industrial products or business processes, reaching the same popularity as Computer Vision. Therefore, this field is too big to ignore. In this post, we will first time talk about an important NLP problem, Named Entity Recognition (NER) which is the task of tagging entities in text with their corresponding type. In addition, we will also build a simple AI-core web application using Gradio and Hugging Face Spaces.Easy Object Detection2022-11-19T00:00:00+00:002022-11-19T00:00:00+00:00https://gather-ai.github.io/tutorials/ezdet<p>👋 Hi there. Welcome back to my page. In the last 2 tutorial series on <a href="https://gather-ai.github.io/tutorials/domain-generalization-part-1/">Domain Generalization</a> and <a href="https://gather-ai.github.io/tutorials/federated-learning-iot-part-1/">Federated Learning on IoT Devices</a>, we dealt with 2 different types of classification (ECG classification and image classification), the most fundamental (and simple) task in Machine Learning (ML). Today, we will explore how to easily move from classification to object detection, a more advanced task in ML and Computer Vision (CV). Let’s get started.</p>
<h2 id="1-background">1. Background</h2>
<h3 id="motivation">Motivation</h3>
<p>I started this tutorial series and the open-source repository <a href="https://github.com/lhkhiem28/ezdet">ezdet</a> based on 3 observations:</p>
<ul>
<li>When people begin to learn ML, specifically CV, they typically start with an image classification tutorial, such as <a href="https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html">PyTorch’s one</a>. After that, they usually move on to object detection, where the difficulty arises: community tutorials on object detection are neither as good as nor structured like the classification ones. Therefore, people turn to open-source repositories like <a href="https://github.com/ultralytics/yolov5">YOLO</a> from Ultralytics or <a href="https://github.com/facebookresearch/detectron2">Detectron</a>. Although these repositories are powerful, they have their own drawbacks.</li>
<li>Available open-source repositories are very complex and equipped with many advanced techniques. This makes them poor starting points for people who are new to the field and want to understand object detection in a similar way to classification. Moreover, the many add-on techniques make it difficult to fairly compare object detection models against each other.</li>
<li>These wonderful repositories are not designed flexibly enough for engineers and researchers to integrate object detection models into other ML projects.</li>
</ul>
<p>With the above observations in mind, I created <a href="https://github.com/lhkhiem28/ezdet">ezdet</a> to overcome these issues. Firstly, ezdet’s source code is organized in a similar way to a classification pipeline. Secondly, ezdet decouples the standard object detection process from other add-on techniques. Finally, ezdet can be easily integrated into other ML projects. At the end of this tutorial, we will embed ezdet into a Federated Learning project.</p>
<h3 id="object-detection">Object Detection</h3>
<p>Object detection is an advanced task in the CV field that deals with the localization and classification of objects contained in an image or video. To see the difference clearly, let’s distinguish between image classification and object detection. Image classification sends a whole image through a classifier, which spits out a single tag; the classifier considers the whole image but doesn’t tell you where the tagged object appears. Object detection is slightly more sophisticated, as it also creates a bounding box around each classified object. Figure 1 illustrates this distinction.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/ezdet/classification-vs-detection.jpg" />
<figcaption>Figure 1. Image Classification vs. Object Detection. Adapted from [1]</figcaption>
</figure>
<p>From an ML perspective, object detection is a multi-task learning problem, a term that we discussed in a <a href="https://gather-ai.github.io/tutorials/domain-generalization-part-2/">previous article</a>. Specifically, detectors are trained with a joint (simplified) objective function as below:</p>
\[\mathcal{L}_{total} = \lambda_{loc}\mathcal{L}_{loc}(\widehat{b}, b) + \lambda_{cls}\mathcal{L}_{cls}(\widehat{y}, y)\]
<p>where $\widehat{b}$ and $b$ are the predicted and ground-truth bounding box coordinates, usually in the form of ($x_{min}$, $y_{min}$, $x_{max}$, $y_{max}$) or ($x_{center}$, $y_{center}$, $width$, $height$); $\widehat{y}$ and $y$ are the predicted probability and ground-truth category of the object in that bounding box. \(\mathcal{L}_{loc}\) can be a simple IoU-based function and \(\mathcal{L}_{cls}\) a cross-entropy loss; $\lambda_{loc}$ and $\lambda_{cls}$ are hyper-parameters that balance the two loss terms.</p>
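<p>To make this objective concrete, a simplified version of the joint loss can be sketched as below. This is only an illustration, not YOLOv3’s exact loss (real detectors add anchor matching and an objectness term): here \(\mathcal{L}_{loc}\) is approximated by a (1 - IoU) term and \(\mathcal{L}_{cls}\) by cross-entropy.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Illustrative sketch of the joint detection objective: a localization
# term (here 1 - IoU on xyxy boxes) plus a classification term.
import torch
import torch.nn.functional as F

def iou_xyxy(pred, target):
    # Intersection-over-Union for boxes in (xmin, ymin, xmax, ymax) format.
    x1 = torch.max(pred[:, 0], target[:, 0])
    y1 = torch.max(pred[:, 1], target[:, 1])
    x2 = torch.min(pred[:, 2], target[:, 2])
    y2 = torch.min(pred[:, 3], target[:, 3])
    inter = (x2 - x1).clamp(min = 0)*(y2 - y1).clamp(min = 0)
    area_p = (pred[:, 2] - pred[:, 0])*(pred[:, 3] - pred[:, 1])
    area_t = (target[:, 2] - target[:, 0])*(target[:, 3] - target[:, 1])
    return inter/(area_p + area_t - inter + 1e-7)

def detection_loss(pred_boxes, true_boxes, cls_logits, cls_targets,
                   lambda_loc = 1.0, lambda_cls = 1.0):
    loss_loc = (1.0 - iou_xyxy(pred_boxes, true_boxes)).mean()
    loss_cls = F.cross_entropy(cls_logits, cls_targets)
    return lambda_loc*loss_loc + lambda_cls*loss_cls
</code></pre></div></div>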
<p>Object detection models typically can be categorized into 2 groups:</p>
<ul>
<li>Two-stage (proposal-based) detectors: The two stages of a two-stage detector are divided by an RoI (Region of Interest) Pooling layer. One of the most prominent two-stage object detectors is <a href="https://arxiv.org/abs/1506.01497">Faster R-CNN</a>. Its first stage is an RPN, a Region Proposal Network that predicts candidate bounding boxes. In the second stage, features are extracted by an RoI pooling operation from each candidate box for the subsequent classification and bounding-box regression tasks (a ready-made implementation is sketched right after this list).</li>
<li>One-stage (proposal-free) detectors: In contrast, a one-stage detector predicts bounding boxes in a single step without using region proposals. It leverages grid cells and anchors to localize the region of detection in the image and constrain the shape of the object. In this tutorial, I will use a <a href="https://arxiv.org/abs/1804.02767">YOLOv3</a> model, a popular one-stage detector, with the <a href="https://github.com/eriklindernoren/PyTorch-YOLOv3">API</a> implemented in PyTorch for demonstration; other architectures will be added in the future.</li>
</ul>
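<p>If you just want to try a two-stage detector before building the pipeline below, torchvision ships a pre-trained Faster R-CNN. A minimal inference sketch (assuming torchvision ≥ 0.13; older versions use <code class="language-plaintext highlighter-rouge">pretrained = True</code> instead of <code class="language-plaintext highlighter-rouge">weights</code>):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Minimal inference sketch with torchvision's pre-trained Faster R-CNN,
# a two-stage detector. Each output dict has "boxes", "labels", "scores".
import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights = "DEFAULT")
model.eval()

images = [torch.rand(3, 416, 416)]  # a dummy RGB image in [0, 1]
with torch.no_grad():
    predictions = model(images)
print(predictions[0]["boxes"].shape)
</code></pre></div></div>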
<h2 id="2-building-an-object-detection-pipeline">2. Building an Object Detection Pipeline</h2>
<p>In this section, we will walk through an end-to-end object detection pipeline, which is built in a similar way to any PyTorch classification pipeline.</p>
<h3 id="voc2007-dataset">VOC2007 Dataset</h3>
<p>The VOC2007 dataset will be used, which contains 5011 training images and 4952 test images. You can download the dataset from <a href="https://pjreddie.com/projects/pascal-voc-dataset-mirror/">here</a>. After downloading, you should organize the dataset into the structure below and then convert the provided labels to YOLO format (a conversion sketch follows the directory listing).</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>│
├───datasets
│ └───VOC2007
│ ├───train
│ │ ├───images
│ │ │ 2007_000005.jpg
│ │ │ ...
│ │ └───labels
│ │ 2007_000005.xml
│ │ ...
│ └───val
│ ├───images
│ │ 2007_000001.jpg
│ │ ...
│ └───labels
│ 2007_000001.xml
| ...
├───source
│ └───*.py
</code></pre></div></div>
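<p>For the label conversion mentioned above, each VOC XML file stores absolute (xmin, ymin, xmax, ymax) corners, while YOLO expects one <code class="language-plaintext highlighter-rouge">class x_center y_center width height</code> line per object, normalized to [0, 1]. Below is a minimal conversion sketch; the class list is abbreviated, so fill in all 20 VOC classes.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Minimal sketch: convert one VOC XML annotation to YOLO-format lines.
# CLASSES is abbreviated here; fill in all 20 VOC class names.
import xml.etree.ElementTree as ET

CLASSES = ["aeroplane", "bicycle", "bird"]  # ... and the rest

def voc_to_yolo(xml_file):
    root = ET.parse(xml_file).getroot()
    w = float(root.find("size/width").text)
    h = float(root.find("size/height").text)
    lines = []
    for obj in root.iter("object"):
        cls = CLASSES.index(obj.find("name").text)
        box = obj.find("bndbox")
        xmin, ymin = float(box.find("xmin").text), float(box.find("ymin").text)
        xmax, ymax = float(box.find("xmax").text), float(box.find("ymax").text)
        x_c, y_c = (xmin + xmax)/2/w, (ymin + ymax)/2/h  # normalized center
        b_w, b_h = (xmax - xmin)/w, (ymax - ymin)/h      # normalized size
        lines.append(f"{cls} {x_c:.6f} {y_c:.6f} {b_w:.6f} {b_h:.6f}")
    return lines
</code></pre></div></div>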
<p>After organizing the dataset, as in any PyTorch classification pipeline, we need to write a <code class="language-plaintext highlighter-rouge">Dataset</code> class with a <code class="language-plaintext highlighter-rouge">__getitem__</code> function that returns a pair of an image and its label (bounding boxes and classes). Snippet 1 illustrates the implementation.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 1: Dataset class.
"""</span>
<span class="kn">from</span> <span class="nn">libs</span> <span class="kn">import</span> <span class="o">*</span>
<span class="k">class</span> <span class="nc">DetImageDataset</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">Dataset</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">images_path</span><span class="p">,</span> <span class="n">labels_path</span>
<span class="p">,</span> <span class="n">image_size</span> <span class="o">=</span> <span class="mi">416</span>
<span class="p">,</span> <span class="n">augment</span> <span class="o">=</span> <span class="bp">False</span>
<span class="p">,</span> <span class="n">multiscale</span> <span class="o">=</span> <span class="bp">False</span>
<span class="p">):</span>
<span class="bp">self</span><span class="p">.</span><span class="n">image_files</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">label_files</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">glob</span><span class="p">.</span><span class="n">glob</span><span class="p">(</span><span class="n">images_path</span> <span class="o">+</span> <span class="s">"/*"</span><span class="p">)),</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">glob</span><span class="p">.</span><span class="n">glob</span><span class="p">(</span><span class="n">labels_path</span> <span class="o">+</span> <span class="s">"/*"</span><span class="p">))</span>
<span class="bp">self</span><span class="p">.</span><span class="n">image_size</span> <span class="o">=</span> <span class="n">image_size</span>
<span class="bp">self</span><span class="p">.</span><span class="n">augment</span> <span class="o">=</span> <span class="n">augment</span>
<span class="bp">self</span><span class="p">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">A</span><span class="p">.</span><span class="n">Compose</span><span class="p">(</span>
<span class="p">[</span>
<span class="n">A</span><span class="p">.</span><span class="n">HorizontalFlip</span><span class="p">(</span>
<span class="n">p</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
<span class="p">),</span>
<span class="n">A</span><span class="p">.</span><span class="n">BBoxSafeRandomCrop</span><span class="p">(</span>
<span class="n">erosion_rate</span> <span class="o">=</span> <span class="mf">0.2</span><span class="p">,</span>
<span class="n">p</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
<span class="p">),</span>
<span class="n">A</span><span class="p">.</span><span class="n">RandomBrightnessContrast</span><span class="p">(</span>
<span class="n">brightness_limit</span> <span class="o">=</span> <span class="mf">0.2</span><span class="p">,</span> <span class="n">contrast_limit</span> <span class="o">=</span> <span class="mf">0.2</span><span class="p">,</span>
<span class="n">p</span> <span class="o">=</span> <span class="mf">0.3</span><span class="p">,</span>
<span class="p">),</span>
<span class="n">A</span><span class="p">.</span><span class="n">RGBShift</span><span class="p">(</span>
<span class="n">r_shift_limit</span> <span class="o">=</span> <span class="mi">30</span><span class="p">,</span> <span class="n">g_shift_limit</span> <span class="o">=</span> <span class="mi">30</span><span class="p">,</span> <span class="n">b_shift_limit</span> <span class="o">=</span> <span class="mi">30</span><span class="p">,</span>
<span class="n">p</span> <span class="o">=</span> <span class="mf">0.3</span><span class="p">,</span>
<span class="p">),</span>
<span class="p">],</span>
<span class="n">A</span><span class="p">.</span><span class="n">BboxParams</span><span class="p">(</span><span class="s">"yolo"</span><span class="p">,</span> <span class="p">[</span><span class="s">"classes"</span><span class="p">])</span>
<span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">multiscale</span> <span class="o">=</span> <span class="n">multiscale</span>
<span class="bp">self</span><span class="p">.</span><span class="n">image_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="p">.</span><span class="n">image_size</span> <span class="o">+</span> <span class="mi">32</span><span class="o">*</span><span class="n">scale</span> <span class="k">for</span> <span class="n">scale</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]</span>
<span class="k">def</span> <span class="nf">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="p">):</span>
<span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">image_files</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">square_pad</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">image</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">_</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">image</span><span class="p">.</span><span class="n">shape</span>
<span class="n">gap_pad</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">abs</span><span class="p">(</span><span class="n">h</span> <span class="o">-</span> <span class="n">w</span><span class="p">)</span>
<span class="k">if</span> <span class="n">h</span> <span class="o">-</span> <span class="n">w</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">pad</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">gap_pad</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="n">gap_pad</span> <span class="o">-</span> <span class="n">gap_pad</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">pad</span> <span class="o">=</span> <span class="p">(</span><span class="n">gap_pad</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="n">gap_pad</span> <span class="o">-</span> <span class="n">gap_pad</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">pad</span><span class="p">(</span>
<span class="n">image</span><span class="p">,</span>
<span class="n">pad</span> <span class="o">=</span> <span class="n">pad</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">image</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">pad</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">index</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">image_file</span><span class="p">,</span> <span class="n">label_file</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">image_files</span><span class="p">[</span><span class="n">index</span><span class="p">],</span> <span class="bp">self</span><span class="p">.</span><span class="n">label_files</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">cv2</span><span class="p">.</span><span class="n">imread</span><span class="p">(</span><span class="n">image_file</span><span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">cv2</span><span class="p">.</span><span class="n">cvtColor</span><span class="p">(</span>
<span class="n">image</span><span class="p">,</span>
<span class="n">code</span> <span class="o">=</span> <span class="n">cv2</span><span class="p">.</span><span class="n">COLOR_BGR2RGB</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">bboxes</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">loadtxt</span><span class="p">(</span><span class="n">label_file</span><span class="p">)</span>
<span class="n">bboxes</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">augment</span><span class="p">:</span>
<span class="n">Transformed</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">transforms</span><span class="p">(</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="p">,</span>
<span class="n">classes</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">bboxes</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:]</span>
<span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">Transformed</span><span class="p">[</span><span class="s">"image"</span><span class="p">]</span>
<span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">Transformed</span><span class="p">[</span><span class="s">"bboxes"</span><span class="p">])</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="p">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">_</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">image</span><span class="p">.</span><span class="n">shape</span>
<span class="n">image</span><span class="p">,</span> <span class="n">pad</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">square_pad</span><span class="p">(</span><span class="n">image</span><span class="p">);</span> <span class="n">_</span><span class="p">,</span> <span class="n">padded_h</span><span class="p">,</span> <span class="n">padded_w</span> <span class="o">=</span> <span class="n">image</span><span class="p">.</span><span class="n">shape</span>
<span class="n">c1</span><span class="p">,</span> <span class="n">c2</span><span class="p">,</span> <span class="n">c3</span><span class="p">,</span> <span class="n">c4</span><span class="p">,</span> <span class="o">=</span> <span class="n">w</span><span class="o">*</span><span class="p">(</span><span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span> <span class="o">+</span> <span class="n">pad</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">w</span><span class="o">*</span><span class="p">(</span><span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span> <span class="o">+</span> <span class="n">pad</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">h</span><span class="o">*</span><span class="p">(</span><span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span> <span class="o">+</span> <span class="n">pad</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">h</span><span class="o">*</span><span class="p">(</span><span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span> <span class="o">+</span> <span class="n">pad</span><span class="p">[</span><span class="mi">4</span><span class="p">],</span>
<span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">],</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">],</span> <span class="o">=</span> <span class="p">((</span><span class="n">c1</span> <span class="o">+</span> <span class="n">c2</span><span class="p">)</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span><span class="o">/</span><span class="n">padded_w</span><span class="p">,</span> <span class="p">((</span><span class="n">c3</span> <span class="o">+</span> <span class="n">c4</span><span class="p">)</span><span class="o">/</span><span class="mi">2</span><span class="p">)</span><span class="o">/</span><span class="n">padded_h</span><span class="p">,</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span><span class="o">*</span><span class="p">(</span><span class="n">w</span><span class="o">/</span><span class="n">padded_w</span><span class="p">),</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span><span class="o">*</span><span class="p">(</span><span class="n">h</span><span class="o">/</span><span class="n">padded_h</span><span class="p">),</span>
<span class="k">return</span> <span class="n">image</span><span class="p">.</span><span class="nb">float</span><span class="p">(),</span> <span class="n">F</span><span class="p">.</span><span class="n">pad</span><span class="p">(</span>
<span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">bboxes</span><span class="p">),</span>
<span class="n">pad</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">value</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">collate_fn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">batch</span><span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">multiscale</span> <span class="ow">and</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">random</span><span class="p">()</span> <span class="o"><=</span> <span class="mf">0.1</span><span class="p">:</span>
<span class="bp">self</span><span class="p">.</span><span class="n">image_size</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">image_sizes</span><span class="p">)</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span>
<span class="n">F</span><span class="p">.</span><span class="n">interpolate</span><span class="p">(</span>
<span class="n">image</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">image_size</span><span class="p">,</span> <span class="n">mode</span> <span class="o">=</span> <span class="s">"nearest"</span><span class="p">,</span>
<span class="p">).</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">image</span> <span class="ow">in</span> <span class="n">images</span>
<span class="p">])</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">images</span><span class="o">/</span><span class="mi">255</span>
<span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">bboxes</span> <span class="k">for</span> <span class="n">bboxes</span> <span class="ow">in</span> <span class="n">labels</span> <span class="k">if</span> <span class="n">bboxes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">]</span>
<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">bboxes</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">labels</span><span class="p">):</span>
<span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">index</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span><span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
<span class="k">return</span> <span class="n">images</span><span class="p">,</span> <span class="n">labels</span>
</code></pre></div></div>
<p>The component that I want you to notice in the above implementation is <code class="language-plaintext highlighter-rouge">self.transforms</code>, which performs data augmentation. Here, I use the <a href="https://albumentations.ai/">Albumentations</a> library; you can modify the <code class="language-plaintext highlighter-rouge">self.transforms</code> attribute to apply whatever data augmentation strategy fits your problem. You can also easily change the image size or the multi-scale training strategy. A short usage example is given below.</p>
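<p>For reference, here is how <code class="language-plaintext highlighter-rouge">DetImageDataset</code> can be wired into a standard PyTorch DataLoader; the custom <code class="language-plaintext highlighter-rouge">collate_fn</code> is what enables variable numbers of boxes per image and multi-scale batching. The paths follow the directory layout shown earlier.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hooking DetImageDataset into a standard DataLoader; the custom
# collate_fn handles variable numbers of boxes and multi-scale resizing.
import torch

train_dataset = DetImageDataset(
    "datasets/VOC2007/train/images", "datasets/VOC2007/train/labels",
    image_size = 416,
    augment = True, multiscale = True,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = 16, shuffle = True,
    collate_fn = train_dataset.collate_fn,
)

images, labels = next(iter(train_loader))
print(images.shape)  # (16, 3, S, S), where S may be rescaled per batch
print(labels.shape)  # (num_boxes_in_batch, 6): batch_idx, class, x, y, w, h
</code></pre></div></div>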
<h3 id="config-the-model">Config the model</h3>
<p>The next step is to configure the YOLO model and set the training hyper-parameters. This is an easy step: based on the provided <a href="https://github.com/pjreddie/darknet/blob/master/cfg/yolov3-voc.cfg"><code class="language-plaintext highlighter-rouge">yolov3.cfg</code></a>, we just need to change the hyper-parameters we will use. For example, I changed the number of epochs, learning rate, and weight decay as below:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>num_epochs=250
lr=0.0001
weight_decay=0.0005
</code></pre></div></div>
<h3 id="a-training-function">A Training Function</h3>
<p>Next, as in any PyTorch classification pipeline, we need a training function. The implementation in Snippet 2 adds a feature that reports the loss and mAP from training and evaluation at each epoch, a feature that is usually missing in many open-source repositories.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 2: Training function.
"""</span>
<span class="kn">from</span> <span class="nn">libs</span> <span class="kn">import</span> <span class="o">*</span>
<span class="k">def</span> <span class="nf">train_fn</span><span class="p">(</span>
<span class="n">train_loaders</span><span class="p">,</span>
<span class="n">model</span><span class="p">,</span>
<span class="n">num_epochs</span><span class="p">,</span>
<span class="n">optimizer</span><span class="p">,</span>
<span class="n">lr_scheduler</span><span class="p">,</span>
<span class="n">save_ckp_dir</span> <span class="o">=</span> <span class="s">"./"</span><span class="p">,</span>
<span class="n">training_verbose</span> <span class="o">=</span> <span class="bp">True</span><span class="p">,</span>
<span class="p">):</span>
<span class="k">print</span><span class="p">(</span><span class="s">"</span><span class="se">\n</span><span class="s">Start Training ...</span><span class="se">\n</span><span class="s">"</span> <span class="o">+</span> <span class="s">" = "</span><span class="o">*</span><span class="mi">16</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">best_map</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">.</span><span class="n">tqdm</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">disable</span> <span class="o">=</span> <span class="n">training_verbose</span><span class="p">):</span>
<span class="k">if</span> <span class="n">training_verbose</span><span class="p">:</span>
<span class="k">print</span><span class="p">(</span><span class="s">"epoch {:2}/{:2}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">num_epochs</span><span class="p">)</span> <span class="o">+</span> <span class="s">"</span><span class="se">\n</span><span class="s">"</span> <span class="o">+</span> <span class="s">" - "</span><span class="o">*</span><span class="mi">16</span><span class="p">)</span>
<span class="k">if</span> <span class="n">epoch</span> <span class="o"><=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.08</span><span class="o">*</span><span class="n">num_epochs</span><span class="p">):</span>
<span class="k">for</span> <span class="n">param_group</span> <span class="ow">in</span> <span class="n">optimizer</span><span class="p">.</span><span class="n">param_groups</span><span class="p">:</span>
<span class="n">param_group</span><span class="p">[</span><span class="s">"lr"</span><span class="p">]</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">hyperparams</span><span class="p">[</span><span class="s">"lr"</span><span class="p">]</span><span class="o">*</span><span class="n">epoch</span><span class="o">/</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="mf">0.08</span><span class="o">*</span><span class="n">num_epochs</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">lr_scheduler</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
<span class="n">wandb</span><span class="p">.</span><span class="n">log</span><span class="p">(</span>
<span class="p">{</span><span class="s">"lr"</span><span class="p">:</span><span class="n">optimizer</span><span class="p">.</span><span class="n">param_groups</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="s">"lr"</span><span class="p">]},</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">epoch</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
<span class="n">running_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="k">for</span> <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">.</span><span class="n">tqdm</span><span class="p">(</span><span class="n">train_loaders</span><span class="p">[</span><span class="s">"train"</span><span class="p">],</span> <span class="n">disable</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">training_verbose</span><span class="p">):</span>
<span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">images</span><span class="p">.</span><span class="n">cuda</span><span class="p">(),</span> <span class="n">labels</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">compute_loss</span><span class="p">(</span>
<span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span>
<span class="n">model</span><span class="p">,</span>
<span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">(),</span> <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="n">running_loss</span> <span class="o">=</span> <span class="n">running_loss</span> <span class="o">+</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span><span class="o">*</span><span class="n">images</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">train_loss</span> <span class="o">=</span> <span class="n">running_loss</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">train_loaders</span><span class="p">[</span><span class="s">"train"</span><span class="p">].</span><span class="n">dataset</span><span class="p">)</span>
<span class="n">wandb</span><span class="p">.</span><span class="n">log</span><span class="p">(</span>
<span class="p">{</span><span class="s">"train_loss"</span><span class="p">:</span><span class="n">train_loss</span><span class="p">},</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">epoch</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">training_verbose</span><span class="p">:</span>
<span class="k">print</span><span class="p">(</span><span class="s">"train - loss:{:.4f}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">train_loss</span><span class="p">))</span>
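<span class="c1"># validation phase: decode predictions with NMS and accumulate statistics for mAP</span>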
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="n">running_classes</span><span class="p">,</span> <span class="n">running_statistics</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">.</span><span class="n">tqdm</span><span class="p">(</span><span class="n">train_loaders</span><span class="p">[</span><span class="s">"val"</span><span class="p">],</span> <span class="n">disable</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">training_verbose</span><span class="p">):</span>
<span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">images</span><span class="p">.</span><span class="n">cuda</span><span class="p">(),</span> <span class="n">labels</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">=</span> <span class="n">xywh2xyxy</span><span class="p">(</span><span class="n">labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:])</span>
<span class="n">labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span><span class="o">*</span><span class="nb">int</span><span class="p">(</span><span class="n">train_loaders</span><span class="p">[</span><span class="s">"val"</span><span class="p">].</span><span class="n">dataset</span><span class="p">.</span><span class="n">image_size</span><span class="p">)</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">non_max_suppression</span><span class="p">(</span>
<span class="n">logits</span><span class="p">,</span>
<span class="n">conf_thres</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">iou_thres</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">running_classes</span><span class="p">,</span> <span class="n">running_statistics</span> <span class="o">=</span> <span class="n">running_classes</span> <span class="o">+</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">].</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">running_statistics</span> <span class="o">+</span> <span class="n">get_batch_statistics</span><span class="p">(</span>
<span class="p">[</span><span class="n">logit</span><span class="p">.</span><span class="n">cpu</span><span class="p">()</span> <span class="k">for</span> <span class="n">logit</span> <span class="ow">in</span> <span class="n">logits</span><span class="p">],</span> <span class="n">labels</span><span class="p">.</span><span class="n">cpu</span><span class="p">(),</span>
<span class="mf">0.5</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">val_map</span> <span class="o">=</span> <span class="n">ap_per_class</span><span class="p">(</span>
<span class="o">*</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">stats</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">stats</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">running_statistics</span><span class="p">))],</span>
<span class="n">running_classes</span><span class="p">,</span>
<span class="p">)[</span><span class="mi">2</span><span class="p">].</span><span class="n">mean</span><span class="p">()</span>
<span class="n">wandb</span><span class="p">.</span><span class="n">log</span><span class="p">(</span>
<span class="p">{</span><span class="s">"val_map"</span><span class="p">:</span><span class="n">val_map</span><span class="p">},</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">epoch</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">training_verbose</span><span class="p">:</span>
<span class="k">print</span><span class="p">(</span><span class="s">"val - map:{:.4f}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">val_map</span><span class="p">))</span>
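<span class="c1"># save a checkpoint whenever the validation mAP improves</span>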
<span class="k">if</span> <span class="n">best_map</span> <span class="o"><</span> <span class="n">val_map</span><span class="p">:</span>
<span class="n">best_map</span> <span class="o">=</span> <span class="n">val_map</span><span class="p">;</span> <span class="n">torch</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s">"{}/yolov3.ptl"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">save_ckp_dir</span><span class="p">))</span>
</code></pre></div></div>
<p>Notice the arguments that the above function receives, <code class="language-plaintext highlighter-rouge">optimizer</code> and <code class="language-plaintext highlighter-rouge">lr_scheduler</code> in particular. You can create any <code class="language-plaintext highlighter-rouge">optimizer</code>, such as SGD or Adam, and any <code class="language-plaintext highlighter-rouge">lr_scheduler</code>, such as StepLR or CosineAnnealingLR, pass them into the function, and the function will handle the rest.</p>
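<p>For example, a minimal sketch that swaps in SGD and StepLR (the hyperparameter values here are hypothetical, and <code class="language-plaintext highlighter-rouge">model</code> is the same Darknet object created in the snippets below):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># A hedged sketch, not the repo's default: SGD with momentum and a step decay schedule.
import torch.optim as optim

optimizer = optim.SGD(
    model.parameters(),
    lr = model.hyperparams["lr"], momentum = 0.9,
)
lr_scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size = 30, gamma = 0.1,  # decay the lr by 10x every 30 epochs
)
</code></pre></div></div>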
<h3 id="start-training">Start training</h3>
<p>Now, we are ready to train our YOLO model. Firstly, let’s initialize PyTorch data loaders:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 3: Data Loaders.
"""</span>
<span class="n">datasets</span> <span class="o">=</span> <span class="p">{</span>
<span class="s">"train"</span><span class="p">:</span><span class="n">DetImageDataset</span><span class="p">(</span>
<span class="n">images_path</span> <span class="o">=</span> <span class="s">"../datasets/{}/train/images"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">),</span> <span class="n">labels_path</span> <span class="o">=</span> <span class="s">"../datasets/{}/train/labels"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
<span class="p">,</span> <span class="n">image_size</span> <span class="o">=</span> <span class="mi">416</span>
<span class="p">,</span> <span class="n">augment</span> <span class="o">=</span> <span class="bp">True</span>
<span class="p">,</span> <span class="n">multiscale</span> <span class="o">=</span> <span class="bp">True</span>
<span class="p">),</span>
<span class="s">"val"</span><span class="p">:</span><span class="n">DetImageDataset</span><span class="p">(</span>
<span class="n">images_path</span> <span class="o">=</span> <span class="s">"../datasets/{}/val/images"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">),</span> <span class="n">labels_path</span> <span class="o">=</span> <span class="s">"../datasets/{}/val/labels"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
<span class="p">,</span> <span class="n">image_size</span> <span class="o">=</span> <span class="mi">416</span>
<span class="p">,</span> <span class="n">augment</span> <span class="o">=</span> <span class="bp">False</span>
<span class="p">,</span> <span class="n">multiscale</span> <span class="o">=</span> <span class="bp">False</span>
<span class="p">),</span>
<span class="p">}</span>
<span class="n">train_loaders</span> <span class="o">=</span> <span class="p">{</span>
<span class="s">"train"</span><span class="p">:</span><span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">datasets</span><span class="p">[</span><span class="s">"train"</span><span class="p">],</span> <span class="n">collate_fn</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">[</span><span class="s">"train"</span><span class="p">].</span><span class="n">collate_fn</span><span class="p">,</span>
<span class="n">num_workers</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">shuffle</span> <span class="o">=</span> <span class="bp">True</span><span class="p">,</span>
<span class="p">),</span>
<span class="s">"val"</span><span class="p">:</span><span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">datasets</span><span class="p">[</span><span class="s">"val"</span><span class="p">],</span> <span class="n">collate_fn</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">[</span><span class="s">"val"</span><span class="p">].</span><span class="n">collate_fn</span><span class="p">,</span>
<span class="n">num_workers</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">shuffle</span> <span class="o">=</span> <span class="bp">False</span><span class="p">,</span>
<span class="p">),</span>
<span class="p">}</span>
</code></pre></div></div>
<p>Next, we will initialize a YOLO model, load pre-trained backbone weights, and create an <code class="language-plaintext highlighter-rouge">optimizer</code> and an <code class="language-plaintext highlighter-rouge">lr_scheduler</code>:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 4: Model Initialization.
"""</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Darknet</span><span class="p">(</span><span class="s">"pytorchyolo/configs/yolov3.cfg"</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="n">load_darknet_weights</span><span class="p">(</span><span class="s">"../ckps/darknet53.conv.74"</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span>
<span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span>
<span class="n">lr</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">hyperparams</span><span class="p">[</span><span class="s">"lr"</span><span class="p">],</span> <span class="n">weight_decay</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">hyperparams</span><span class="p">[</span><span class="s">"weight_decay"</span><span class="p">],</span>
<span class="p">)</span>
<span class="n">lr_scheduler</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">lr_scheduler</span><span class="p">.</span><span class="n">CosineAnnealingLR</span><span class="p">(</span>
<span class="n">optimizer</span><span class="p">,</span>
<span class="n">eta_min</span> <span class="o">=</span> <span class="mf">0.01</span><span class="o">*</span><span class="n">model</span><span class="p">.</span><span class="n">hyperparams</span><span class="p">[</span><span class="s">"lr"</span><span class="p">],</span> <span class="n">T_max</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.92</span><span class="o">*</span><span class="nb">int</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">hyperparams</span><span class="p">[</span><span class="s">"num_epochs"</span><span class="p">])),</span>
<span class="p">)</span>
</code></pre></div></div>
<p>Finally, start training. Don’t forget to use wandb to log the results:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 5: Training.
"""</span>
<span class="n">wandb</span><span class="p">.</span><span class="n">login</span><span class="p">()</span>
<span class="n">wandb</span><span class="p">.</span><span class="n">init</span><span class="p">(</span>
<span class="n">project</span> <span class="o">=</span> <span class="s">"ezdet"</span><span class="p">,</span> <span class="n">name</span> <span class="o">=</span> <span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">save_ckp_dir</span> <span class="o">=</span> <span class="s">"../ckps/{}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">exists</span><span class="p">(</span><span class="n">save_ckp_dir</span><span class="p">):</span>
<span class="n">os</span><span class="p">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">save_ckp_dir</span><span class="p">)</span>
<span class="n">train_fn</span><span class="p">(</span>
<span class="n">train_loaders</span><span class="p">,</span>
<span class="n">model</span><span class="p">,</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">hyperparams</span><span class="p">[</span><span class="s">"num_epochs"</span><span class="p">]),</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span><span class="p">,</span>
<span class="n">lr_scheduler</span> <span class="o">=</span> <span class="n">lr_scheduler</span><span class="p">,</span>
<span class="n">save_ckp_dir</span> <span class="o">=</span> <span class="n">save_ckp_dir</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">wandb</span><span class="p">.</span><span class="n">finish</span><span class="p">()</span>
</code></pre></div></div>
<h2 id="3-results">3. Results</h2>
<p>If you follow this tutorial carefully, your results should look like this:</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/ezdet/metrics.jpg" />
<figcaption>Figure 2. Training Loss and Validation mAP. </figcaption>
</figure>
<h2 id="4-integrating-ezdet-into-other-ml-projects">4. Integrating ezdet into other ML projects</h2>
<p>As mentioned above, the ezdet repository can be easily integrated into other ML projects. To demonstrate this, I have used ezdet and <a href="https://flower.dev/">Flower</a> to train YOLOv3 models in a Federated Learning setting. You can find the full implementation <a href="https://github.com/lhkhiem28/FedDet">here</a>.</p>
<p>Stay tuned for more content …</p>
<h2 id="references">References</h2>
<p style="font-size: 14px;"><a href="https://www.v7labs.com/blog/object-detection-guide">[1] The Ultimate Guide to Object Detection</a><br />
<a href="https://www.researchgate.net/publication/349297260_Bibliometric_Analysis_of_One-stage_and_Two-stage_Object_Detection">[2] Bibliometric Analysis of One-stage and Two-stage Object Detection</a><br /></p>Gathering.AI👋 Hi there. Welcome back to my page. In the last 2 tutorial series on Domain Generalization and Federated Learning on IoT Devices, we dealt with 2 different types of classification (ECG classification and image classification), the most fundamental (and simple) task in Machine Learning (ML). Today, we will explore how to easily move from classification to object detection, a more advanced task in ML and Computer Vision (CV). Let’s get started.Federated Learning on IoT Devices - Part 22022-10-29T00:00:00+00:002022-10-29T00:00:00+00:00https://gather-ai.github.io/tutorials/federated-learning-iot-part-2<p>👋 Hi there. Welcome back to my page, this is part 2 of my tutorial series on deploying Federated Learning on IoT devices. In the <a href="https://gather-ai.github.io/tutorials/federated-learning-iot-part-1/">last article</a>, we discussed what FL is and built a network of IoT devices as well as environments for starting work. Today, I will guide you step by step to train a simple CNN model on the CIFAR10 dataset in real IoT devices by using <a href="https://flower.dev/">Flower</a>. Let’s get started.</p>
<h2 id="1-preparing-dataset">1. Preparing Dataset</h2>
<h3 id="cifar10-dataset">CIFAR10 Dataset</h3>
<p>The CIFAR10 dataset consists of 60000 32x32 color images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. Here are the classes in the dataset, as well as 10 random images from each:</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/federated-learning-iot/cifar10.jpg" />
<figcaption>Figure 1. CIFAR10 Dataset. Adapted from [1]</figcaption>
</figure>
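<p>If you want to materialize CIFAR10 on disk yourself, a hedged sketch of an export script could look like the one below; it dumps each training image to a .npy file together with an id/label index, matching what the Dataset class in the next subsection expects. The paths and the CSV layout are my assumptions, not necessarily those used in this series.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hypothetical export script: save CIFAR10 images as .npy files plus an id/label index.
import os
import numpy as np
import pandas as pd
from torchvision.datasets import CIFAR10

os.makedirs("../datasets/CIFAR10/fit", exist_ok = True)
dataset = CIFAR10(root = "../datasets", train = True, download = True)
rows = []
for i, (image, label) in enumerate(dataset):
    np.save("../datasets/CIFAR10/fit/{}.npy".format(i), np.array(image))
    rows.append({"id":i, "label":label})
pd.DataFrame(rows).to_csv("../datasets/CIFAR10/fit.csv", index = False)
</code></pre></div></div>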
<h3 id="data-partitioning">Data Partitioning</h3>
<p>In this tutorial, the training data are assigned to the clients in an IID setting. As mentioned before, our network has 10 clients in total, so the training data is shuffled and uniformly divided into 10 partitions of 5000 images, one per client, as sketched below. Note that a partition might not contain exactly 500 images of each class.</p>
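<p>A minimal sketch of this IID partitioning (a hypothetical helper, assuming 50000 training examples and 10 clients):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Shuffle all training indices once, then split them into equal client partitions.
import numpy as np

def iid_partition(num_examples = 50000, num_clients = 10, seed = 0):
    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_examples)
    return np.array_split(indices, num_clients)  # 10 partitions of 5000 indices each

partitions = iid_partition()
</code></pre></div></div>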
<p>After assigning data to clients, let’s implement a Dataset class, which will be used in a PyTorch DataLoader.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 1: Dataset class.
"""</span>
<span class="kn">from</span> <span class="nn">libs</span> <span class="kn">import</span> <span class="o">*</span>
<span class="k">class</span> <span class="nc">ImageDataset</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">Dataset</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">df</span><span class="p">,</span> <span class="n">data_path</span><span class="p">,</span>
<span class="n">image_size</span> <span class="o">=</span> <span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">)</span>
<span class="p">):</span>
<span class="bp">self</span><span class="p">.</span><span class="n">df</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">data_path</span><span class="p">,</span> <span class="o">=</span> <span class="n">df</span><span class="p">,</span> <span class="n">data_path</span><span class="p">,</span>
<span class="bp">self</span><span class="p">.</span><span class="n">image_size</span> <span class="o">=</span> <span class="n">image_size</span>
<span class="k">def</span> <span class="nf">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="p">):</span>
<span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">df</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">index</span>
<span class="p">):</span>
<span class="n">row</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">df</span><span class="p">.</span><span class="n">iloc</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"{}/{}.npy"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">data_path</span><span class="p">,</span> <span class="n">row</span><span class="p">[</span><span class="s">"id"</span><span class="p">]))</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">cv2</span><span class="p">.</span><span class="n">resize</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">image_size</span><span class="p">)</span><span class="o">/</span><span class="mi">255</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">image</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="o"><</span> <span class="mi">3</span><span class="p">:</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">image</span><span class="p">).</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">row</span><span class="p">[</span><span class="s">"label"</span><span class="p">]</span>
</code></pre></div></div>
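<p>As a usage sketch, a client can wrap its partition in a PyTorch DataLoader like this (<code class="language-plaintext highlighter-rouge">client_df</code> and <code class="language-plaintext highlighter-rouge">client_data_path</code> are hypothetical names for that client’s index dataframe and image folder):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hedged usage sketch: feed one client's partition to a PyTorch DataLoader.
import torch

fit_loader = torch.utils.data.DataLoader(
    ImageDataset(client_df, client_data_path, image_size = (32, 32)),
    batch_size = 32, shuffle = True,
)
</code></pre></div></div>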
<h2 id="2-ingredients-for-training">2. Ingredients for Training</h2>
<h3 id="a-simple-cnn-model">A Simple CNN Model</h3>
<p>For simplicity, I use LeNet5, a pioneering CNN model, for deployment. Snippet 2 is an implementation of this model.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 2: LeNet5 model.
"""</span>
<span class="kn">from</span> <span class="nn">libs</span> <span class="kn">import</span> <span class="o">*</span>
<span class="k">class</span> <span class="nc">LeNet5</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">in_channels</span><span class="p">,</span> <span class="n">num_classes</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">LeNet5</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="p">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="n">nn</span><span class="p">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="n">kernel_size</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">stride</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">padding</span> <span class="o">=</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="mi">6</span><span class="p">),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="n">kernel_size</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="n">stride</span> <span class="o">=</span> <span class="mi">2</span><span class="p">),</span>
<span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="n">nn</span><span class="p">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="n">kernel_size</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">stride</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">padding</span> <span class="o">=</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="mi">16</span><span class="p">),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="n">kernel_size</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="n">stride</span> <span class="o">=</span> <span class="mi">2</span><span class="p">),</span>
<span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">layer3</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">400</span><span class="p">,</span> <span class="mi">120</span><span class="p">),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">120</span><span class="p">,</span> <span class="mi">84</span><span class="p">),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">classifier</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">84</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="nb">input</span>
<span class="p">):</span>
<span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">layer1</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">layer2</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="nb">input</span> <span class="o">=</span> <span class="nb">input</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="nb">input</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">input</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">layer3</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="n">logit</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">classifier</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="k">return</span> <span class="n">logit</span>
</code></pre></div></div>
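<p>A quick sanity check with a random CIFAR10-sized batch confirms the output shape:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># One CIFAR10-shaped batch through the model should yield 10 logits per image.
import torch

model = LeNet5(in_channels = 3, num_classes = 10)
print(model(torch.randn(4, 3, 32, 32)).shape)  # torch.Size([4, 10])
</code></pre></div></div>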
<h3 id="a-training-function">A Training Function</h3>
<p>We need a function that each client will use to perform training on their own data. All metrics during training should be logged and returned in a dictionary.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 3: Training function.
"""</span>
<span class="kn">from</span> <span class="nn">libs</span> <span class="kn">import</span> <span class="o">*</span>
<span class="k">def</span> <span class="nf">client_fit_fn</span><span class="p">(</span>
<span class="n">loaders</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">device</span><span class="p">(</span><span class="s">"cpu"</span><span class="p">),</span>
<span class="n">save_ckp_path</span> <span class="o">=</span> <span class="s">"./ckp.ptl"</span><span class="p">,</span> <span class="n">training_verbose</span> <span class="o">=</span> <span class="bp">True</span>
<span class="p">):</span>
<span class="k">print</span><span class="p">(</span><span class="s">"</span><span class="se">\n</span><span class="s">Start Client Training ...</span><span class="se">\n</span><span class="s">"</span> <span class="o">+</span> <span class="s">" = "</span><span class="o">*</span><span class="mi">16</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">criterion</span><span class="p">,</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">CrossEntropyLoss</span><span class="p">(),</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span> <span class="o">=</span> <span class="mf">1e-3</span><span class="p">)</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">disable</span> <span class="o">=</span> <span class="n">training_verbose</span><span class="p">):</span>
<span class="k">if</span> <span class="n">training_verbose</span><span class="p">:</span><span class="k">print</span><span class="p">(</span><span class="s">"epoch {:2}/{:2}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">num_epochs</span><span class="p">)</span> <span class="o">+</span> <span class="s">"</span><span class="se">\n</span><span class="s">"</span> <span class="o">+</span> <span class="s">" - "</span><span class="o">*</span><span class="mi">16</span><span class="p">)</span>
<span class="n">running_loss</span><span class="p">,</span> <span class="n">running_correct</span><span class="p">,</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
<span class="k">for</span> <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">loaders</span><span class="p">[</span><span class="s">"fit"</span><span class="p">],</span> <span class="n">disable</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">training_verbose</span><span class="p">):</span>
<span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">images</span><span class="p">.</span><span class="nb">float</span><span class="p">().</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">labels</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
<span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">(),</span> <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="n">running_loss</span><span class="p">,</span> <span class="n">running_correct</span><span class="p">,</span> <span class="o">=</span> <span class="n">running_loss</span> <span class="o">+</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span><span class="o">*</span><span class="n">images</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">running_correct</span> <span class="o">+</span> <span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">logits</span><span class="p">.</span><span class="n">data</span><span class="p">,</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">].</span><span class="n">detach</span><span class="p">().</span><span class="n">cpu</span><span class="p">()</span> <span class="o">==</span> <span class="n">labels</span><span class="p">.</span><span class="n">cpu</span><span class="p">()).</span><span class="nb">sum</span><span class="p">().</span><span class="n">item</span><span class="p">(),</span>
<span class="n">fit_loss</span><span class="p">,</span> <span class="n">fit_accuracy</span><span class="p">,</span> <span class="o">=</span> <span class="n">running_loss</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">loaders</span><span class="p">[</span><span class="s">"fit"</span><span class="p">].</span><span class="n">dataset</span><span class="p">),</span> <span class="n">running_correct</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">loaders</span><span class="p">[</span><span class="s">"fit"</span><span class="p">].</span><span class="n">dataset</span><span class="p">),</span>
<span class="k">if</span> <span class="n">training_verbose</span><span class="p">:</span>
<span class="k">print</span><span class="p">(</span><span class="s">"{:<5} - loss:{:.4f}, accuracy:{:.4f}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span>
<span class="s">"fit"</span><span class="p">,</span>
<span class="n">fit_loss</span><span class="p">,</span> <span class="n">fit_accuracy</span><span class="p">,</span>
<span class="p">))</span>
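<span class="c1"># evaluation pass on the client's local eval split</span>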
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="n">running_loss</span><span class="p">,</span> <span class="n">running_correct</span><span class="p">,</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
<span class="k">for</span> <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">loaders</span><span class="p">[</span><span class="s">"eval"</span><span class="p">],</span> <span class="n">disable</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">training_verbose</span><span class="p">):</span>
<span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">images</span><span class="p">.</span><span class="nb">float</span><span class="p">().</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">labels</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
<span class="n">running_loss</span><span class="p">,</span> <span class="n">running_correct</span><span class="p">,</span> <span class="o">=</span> <span class="n">running_loss</span> <span class="o">+</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span><span class="o">*</span><span class="n">images</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">running_correct</span> <span class="o">+</span> <span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">logits</span><span class="p">.</span><span class="n">data</span><span class="p">,</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">].</span><span class="n">detach</span><span class="p">().</span><span class="n">cpu</span><span class="p">()</span> <span class="o">==</span> <span class="n">labels</span><span class="p">.</span><span class="n">cpu</span><span class="p">()).</span><span class="nb">sum</span><span class="p">().</span><span class="n">item</span><span class="p">(),</span>
<span class="n">eval_loss</span><span class="p">,</span> <span class="n">eval_accuracy</span><span class="p">,</span> <span class="o">=</span> <span class="n">running_loss</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">loaders</span><span class="p">[</span><span class="s">"eval"</span><span class="p">].</span><span class="n">dataset</span><span class="p">),</span> <span class="n">running_correct</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">loaders</span><span class="p">[</span><span class="s">"eval"</span><span class="p">].</span><span class="n">dataset</span><span class="p">),</span>
<span class="k">if</span> <span class="n">training_verbose</span><span class="p">:</span>
<span class="k">print</span><span class="p">(</span><span class="s">"{:<5} - loss:{:.4f}, accuracy:{:.4f}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span>
<span class="s">"eval"</span><span class="p">,</span>
<span class="n">eval_loss</span><span class="p">,</span> <span class="n">eval_accuracy</span><span class="p">,</span>
<span class="p">))</span>
<span class="n">torch</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">save_ckp_path</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"</span><span class="se">\n</span><span class="s">Finish Client Training ...</span><span class="se">\n</span><span class="s">"</span> <span class="o">+</span> <span class="s">" = "</span><span class="o">*</span><span class="mi">16</span><span class="p">)</span>
<span class="k">return</span> <span class="p">{</span>
<span class="s">"fit_loss"</span><span class="p">:</span><span class="n">fit_loss</span><span class="p">,</span> <span class="s">"fit_accuracy"</span><span class="p">:</span><span class="n">fit_accuracy</span><span class="p">,</span>
<span class="s">"eval_loss"</span><span class="p">:</span><span class="n">eval_loss</span><span class="p">,</span> <span class="s">"eval_accuracy"</span><span class="p">:</span><span class="n">eval_accuracy</span><span class="p">,</span>
<span class="p">}</span>
</code></pre></div></div>
<h2 id="3-server-site">3. Server Site</h2>
<p>We can use our laptop as the server. At each round, the server sends the global model to all clients, which perform on-device training. When the clients finish training, they send their local models back to the server; the global model is then updated by an FL strategy, for example FedAvg, in which the server averages all client models before starting the next round.</p>
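<p>Conceptually, the FedAvg update is just an example-weighted average of the client weights. A minimal standalone sketch (a hypothetical helper, not Flower’s own implementation):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Each client contributes proportionally to its number of training examples.
def fedavg(client_weights, num_examples):
    # client_weights: one list of numpy arrays (layers) per client
    total = sum(num_examples)
    return [
        sum(n*layer for n, layer in zip(num_examples, layers))/total
        for layers in zip(*client_weights)  # iterate layer-by-layer across clients
    ]
</code></pre></div></div>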
<p>We will modify the <code class="language-plaintext highlighter-rouge">FedAvg</code> class of Flower to save the global model at each round.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 4: FedAvg strategy.
"""</span>
<span class="kn">from</span> <span class="nn">libs</span> <span class="kn">import</span> <span class="o">*</span>
<span class="k">def</span> <span class="nf">metrics_aggregation_fn</span><span class="p">(</span><span class="n">metrics</span><span class="p">):</span>
<span class="n">fit_losses</span><span class="p">,</span> <span class="n">fit_accuracies</span><span class="p">,</span> <span class="o">=</span> <span class="p">[</span><span class="n">metric</span><span class="p">[</span><span class="s">"fit_loss"</span><span class="p">]</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">metric</span> <span class="ow">in</span> <span class="n">metrics</span><span class="p">],</span> <span class="p">[</span><span class="n">metric</span><span class="p">[</span><span class="s">"fit_accuracy"</span><span class="p">]</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">metric</span> <span class="ow">in</span> <span class="n">metrics</span><span class="p">],</span>
<span class="n">eval_losses</span><span class="p">,</span> <span class="n">eval_accuracies</span><span class="p">,</span> <span class="o">=</span> <span class="p">[</span><span class="n">metric</span><span class="p">[</span><span class="s">"eval_loss"</span><span class="p">]</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">metric</span> <span class="ow">in</span> <span class="n">metrics</span><span class="p">],</span> <span class="p">[</span><span class="n">metric</span><span class="p">[</span><span class="s">"eval_accuracy"</span><span class="p">]</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">metric</span> <span class="ow">in</span> <span class="n">metrics</span><span class="p">],</span>
<span class="n">aggregated_metrics</span> <span class="o">=</span> <span class="p">{</span>
<span class="s">"fit_loss"</span><span class="p">:</span><span class="nb">sum</span><span class="p">(</span><span class="n">fit_losses</span><span class="p">)</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">fit_losses</span><span class="p">),</span> <span class="s">"fit_accuracy"</span><span class="p">:</span><span class="nb">sum</span><span class="p">(</span><span class="n">fit_accuracies</span><span class="p">)</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">fit_accuracies</span><span class="p">),</span>
<span class="s">"eval_loss"</span><span class="p">:</span><span class="nb">sum</span><span class="p">(</span><span class="n">eval_losses</span><span class="p">)</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">eval_losses</span><span class="p">),</span> <span class="s">"eval_accuracy"</span><span class="p">:</span><span class="nb">sum</span><span class="p">(</span><span class="n">eval_accuracies</span><span class="p">)</span><span class="o">/</span><span class="nb">len</span><span class="p">(</span><span class="n">eval_accuracies</span><span class="p">),</span>
<span class="p">}</span>
<span class="k">return</span> <span class="n">aggregated_metrics</span>
<span class="k">class</span> <span class="nc">FedAvg</span><span class="p">(</span><span class="n">fl</span><span class="p">.</span><span class="n">server</span><span class="p">.</span><span class="n">strategy</span><span class="p">.</span><span class="n">FedAvg</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">initial_model</span><span class="p">,</span>
<span class="n">save_ckp_path</span><span class="p">,</span>
<span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span>
<span class="p">):</span>
<span class="bp">self</span><span class="p">.</span><span class="n">initial_model</span> <span class="o">=</span> <span class="n">initial_model</span>
<span class="bp">self</span><span class="p">.</span><span class="n">save_ckp_path</span> <span class="o">=</span> <span class="n">save_ckp_path</span>
<span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">aggregate_fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">server_round</span><span class="p">,</span>
<span class="n">results</span><span class="p">,</span> <span class="n">failures</span>
<span class="p">):</span>
<span class="n">aggregated_metrics</span> <span class="o">=</span> <span class="n">metrics_aggregation_fn</span><span class="p">([(</span><span class="n">result</span><span class="p">.</span><span class="n">num_examples</span><span class="p">,</span> <span class="n">result</span><span class="p">.</span><span class="n">metrics</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">result</span> <span class="ow">in</span> <span class="n">results</span><span class="p">])</span>
<span class="n">wandb</span><span class="p">.</span><span class="n">log</span><span class="p">({</span><span class="s">"fit_loss"</span><span class="p">:</span><span class="n">aggregated_metrics</span><span class="p">[</span><span class="s">"fit_loss"</span><span class="p">]},</span> <span class="n">step</span> <span class="o">=</span> <span class="n">server_round</span><span class="p">),</span> <span class="n">wandb</span><span class="p">.</span><span class="n">log</span><span class="p">({</span><span class="s">"fit_accuracy"</span><span class="p">:</span><span class="n">aggregated_metrics</span><span class="p">[</span><span class="s">"fit_accuracy"</span><span class="p">]},</span> <span class="n">step</span> <span class="o">=</span> <span class="n">server_round</span><span class="p">),</span>
<span class="n">wandb</span><span class="p">.</span><span class="n">log</span><span class="p">({</span><span class="s">"eval_loss"</span><span class="p">:</span><span class="n">aggregated_metrics</span><span class="p">[</span><span class="s">"eval_loss"</span><span class="p">]},</span> <span class="n">step</span> <span class="o">=</span> <span class="n">server_round</span><span class="p">),</span> <span class="n">wandb</span><span class="p">.</span><span class="n">log</span><span class="p">({</span><span class="s">"eval_accuracy"</span><span class="p">:</span><span class="n">aggregated_metrics</span><span class="p">[</span><span class="s">"eval_accuracy"</span><span class="p">]},</span> <span class="n">step</span> <span class="o">=</span> <span class="n">server_round</span><span class="p">),</span>
<span class="n">aggregated_parameters</span><span class="p">,</span> <span class="n">results</span> <span class="o">=</span> <span class="nb">super</span><span class="p">().</span><span class="n">aggregate_fit</span><span class="p">(</span>
<span class="n">server_round</span><span class="p">,</span>
<span class="n">results</span><span class="p">,</span> <span class="n">failures</span>
<span class="p">)</span>
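<span class="c1"># load the aggregated weights into a local model and save a global checkpoint</span>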
<span class="k">if</span> <span class="n">aggregated_parameters</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
<span class="bp">self</span><span class="p">.</span><span class="n">initial_model</span><span class="p">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">OrderedDict</span><span class="p">({</span><span class="n">key</span><span class="p">:</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">initial_model</span><span class="p">.</span><span class="n">state_dict</span><span class="p">().</span><span class="n">keys</span><span class="p">(),</span> <span class="n">fl</span><span class="p">.</span><span class="n">common</span><span class="p">.</span><span class="n">parameters_to_weights</span><span class="p">(</span><span class="n">aggregated_parameters</span><span class="p">))}),</span> <span class="n">strict</span> <span class="o">=</span> <span class="bp">True</span><span class="p">)</span>
<span class="n">torch</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">initial_model</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">save_ckp_path</span><span class="p">)</span>
<span class="k">return</span> <span class="n">aggregated_parameters</span><span class="p">,</span> <span class="p">{}</span>
</code></pre></div></div>
<p>The server can be easily started by passing your laptop’s IP address and an arbitrary port into the <code class="language-plaintext highlighter-rouge">start_server</code> function.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 5: Server site.
"""</span>
<span class="kn">from</span> <span class="nn">libs</span> <span class="kn">import</span> <span class="o">*</span>
<span class="kn">from</span> <span class="nn">data</span> <span class="kn">import</span> <span class="n">ImageDataset</span>
<span class="kn">from</span> <span class="nn">nets</span> <span class="kn">import</span> <span class="n">LeNet5</span>
<span class="kn">from</span> <span class="nn">strategies</span> <span class="kn">import</span> <span class="n">FedAvg</span>
<span class="kn">from</span> <span class="nn">engines</span> <span class="kn">import</span> <span class="n">server_test_fn</span>
<span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="p">.</span><span class="n">ArgumentParser</span><span class="p">()</span>
<span class="n">parser</span><span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"--server_address"</span><span class="p">,</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default</span> <span class="o">=</span> <span class="s">"192.168.50.102"</span><span class="p">),</span> <span class="n">parser</span><span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"--server_port"</span><span class="p">,</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">int</span><span class="p">)</span>
<span class="n">parser</span><span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"--dataset"</span><span class="p">,</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default</span> <span class="o">=</span> <span class="s">"CIFAR10"</span><span class="p">),</span> <span class="n">parser</span><span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"--num_clients"</span><span class="p">,</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">int</span><span class="p">,</span> <span class="n">default</span> <span class="o">=</span> <span class="mi">10</span><span class="p">)</span>
<span class="n">parser</span><span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"--num_rounds"</span><span class="p">,</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">int</span><span class="p">,</span> <span class="n">default</span> <span class="o">=</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">args</span> <span class="o">=</span> <span class="n">parser</span><span class="p">.</span><span class="n">parse_args</span><span class="p">()</span>
<span class="n">wandb</span><span class="p">.</span><span class="n">login</span><span class="p">()</span>
<span class="n">wandb</span><span class="p">.</span><span class="n">init</span><span class="p">(</span><span class="n">project</span> <span class="o">=</span> <span class="s">"FL-IoT"</span><span class="p">,</span> <span class="n">name</span> <span class="o">=</span> <span class="s">"{}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">))</span>
<span class="n">initial_model</span> <span class="o">=</span> <span class="n">LeNet5</span><span class="p">(</span><span class="mi">1</span> <span class="k">if</span> <span class="s">"MNIST"</span> <span class="ow">in</span> <span class="n">args</span><span class="p">.</span><span class="n">dataset</span> <span class="k">else</span> <span class="mi">3</span><span class="p">,</span> <span class="n">num_classes</span> <span class="o">=</span> <span class="mi">10</span><span class="p">)</span>
<span class="n">initial_parameters</span> <span class="o">=</span> <span class="p">[</span><span class="n">value</span><span class="p">.</span><span class="n">cpu</span><span class="p">().</span><span class="n">numpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">initial_model</span><span class="p">.</span><span class="n">state_dict</span><span class="p">().</span><span class="n">items</span><span class="p">()]</span>
<span class="n">save_ckp_path</span> <span class="o">=</span> <span class="s">"../ckps/{}/server.ptl"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">exists</span><span class="p">(</span><span class="s">"/"</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">save_ckp_path</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="s">"/"</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])):</span>
<span class="n">os</span><span class="p">.</span><span class="n">makedirs</span><span class="p">(</span><span class="s">"/"</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">save_ckp_path</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="s">"/"</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span>
<span class="n">fl</span><span class="p">.</span><span class="n">server</span><span class="p">.</span><span class="n">start_server</span><span class="p">(</span>
<span class="n">server_address</span> <span class="o">=</span> <span class="s">"{}:{}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">server_address</span><span class="p">,</span> <span class="n">args</span><span class="p">.</span><span class="n">server_port</span><span class="p">),</span>
<span class="n">config</span> <span class="o">=</span> <span class="p">{</span><span class="s">"num_rounds"</span><span class="p">:</span><span class="n">args</span><span class="p">.</span><span class="n">num_rounds</span><span class="p">},</span>
<span class="n">strategy</span> <span class="o">=</span> <span class="n">FedAvg</span><span class="p">(</span><span class="n">min_available_clients</span> <span class="o">=</span> <span class="n">args</span><span class="p">.</span><span class="n">num_clients</span><span class="p">,</span>
<span class="n">min_fit_clients</span> <span class="o">=</span> <span class="n">args</span><span class="p">.</span><span class="n">num_clients</span><span class="p">,</span>
<span class="n">min_eval_clients</span> <span class="o">=</span> <span class="n">args</span><span class="p">.</span><span class="n">num_clients</span><span class="p">,</span>
<span class="n">initial_parameters</span> <span class="o">=</span> <span class="n">fl</span><span class="p">.</span><span class="n">common</span><span class="p">.</span><span class="n">weights_to_parameters</span><span class="p">(</span><span class="n">initial_parameters</span><span class="p">),</span>
<span class="n">initial_model</span> <span class="o">=</span> <span class="n">initial_model</span><span class="p">,</span>
<span class="n">save_ckp_path</span> <span class="o">=</span> <span class="n">save_ckp_path</span><span class="p">,</span>
<span class="p">)</span>
<span class="p">)</span>
</code></pre></div></div>
<h2 id="4-client-site">4. Client Site</h2>
<p>For the client, we need to create a <code class="language-plaintext highlighter-rouge">Client</code> class that inherits from Flower’s <code class="language-plaintext highlighter-rouge">NumPyClient</code> and implements 4 methods: <code class="language-plaintext highlighter-rouge">get_parameters</code>, <code class="language-plaintext highlighter-rouge">set_parameters</code>, <code class="language-plaintext highlighter-rouge">fit</code>, and <code class="language-plaintext highlighter-rouge">evaluate</code>. Then, pass the server’s IP address and its open port to <code class="language-plaintext highlighter-rouge">start_numpy_client</code>; the rest is similar to a traditional ML project.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 6: Client site.
"""</span>
<span class="kn">from</span> <span class="nn">libs</span> <span class="kn">import</span> <span class="o">*</span>
<span class="kn">from</span> <span class="nn">data</span> <span class="kn">import</span> <span class="n">ImageDataset</span>
<span class="kn">from</span> <span class="nn">nets</span> <span class="kn">import</span> <span class="n">LeNet5</span>
<span class="kn">from</span> <span class="nn">engines</span> <span class="kn">import</span> <span class="n">client_fit_fn</span>
<span class="k">class</span> <span class="nc">Client</span><span class="p">(</span><span class="n">fl</span><span class="p">.</span><span class="n">client</span><span class="p">.</span><span class="n">NumPyClient</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">loaders</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">device</span><span class="p">(</span><span class="s">"cpu"</span><span class="p">),</span>
<span class="n">save_ckp_path</span> <span class="o">=</span> <span class="s">"./ckp.ptl"</span><span class="p">,</span> <span class="n">training_verbose</span> <span class="o">=</span> <span class="bp">True</span>
<span class="p">):</span>
<span class="bp">self</span><span class="p">.</span><span class="n">loaders</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">model</span><span class="p">,</span> <span class="o">=</span> <span class="n">loaders</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span>
<span class="bp">self</span><span class="p">.</span><span class="n">num_epochs</span> <span class="o">=</span> <span class="n">num_epochs</span>
<span class="bp">self</span><span class="p">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="bp">self</span><span class="p">.</span><span class="n">save_ckp_path</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">training_verbose</span> <span class="o">=</span> <span class="n">save_ckp_path</span><span class="p">,</span> <span class="n">training_verbose</span>
<span class="bp">self</span><span class="p">.</span><span class="n">model</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">model</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">config</span>
<span class="p">):</span>
<span class="bp">self</span><span class="p">.</span><span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
<span class="k">return</span> <span class="p">[</span><span class="n">value</span><span class="p">.</span><span class="n">cpu</span><span class="p">().</span><span class="n">numpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">model</span><span class="p">.</span><span class="n">state_dict</span><span class="p">().</span><span class="n">items</span><span class="p">()]</span>
<span class="k">def</span> <span class="nf">set_parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">parameters</span><span class="p">,</span>
<span class="p">):</span>
<span class="bp">self</span><span class="p">.</span><span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
<span class="bp">self</span><span class="p">.</span><span class="n">model</span><span class="p">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">OrderedDict</span><span class="p">({</span><span class="n">key</span><span class="p">:</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">model</span><span class="p">.</span><span class="n">state_dict</span><span class="p">().</span><span class="n">keys</span><span class="p">(),</span> <span class="n">parameters</span><span class="p">)}),</span> <span class="n">strict</span> <span class="o">=</span> <span class="bp">True</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">parameters</span><span class="p">,</span> <span class="n">config</span>
<span class="p">):</span>
<span class="bp">self</span><span class="p">.</span><span class="n">set_parameters</span><span class="p">(</span><span class="n">parameters</span><span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
<span class="n">history</span> <span class="o">=</span> <span class="n">client_fit_fn</span><span class="p">(</span>
<span class="bp">self</span><span class="p">.</span><span class="n">loaders</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">model</span><span class="p">,</span>
<span class="bp">self</span><span class="p">.</span><span class="n">num_epochs</span><span class="p">,</span>
<span class="bp">self</span><span class="p">.</span><span class="n">device</span><span class="p">,</span>
<span class="bp">self</span><span class="p">.</span><span class="n">save_ckp_path</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">training_verbose</span>
<span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">get_parameters</span><span class="p">(</span><span class="n">config</span> <span class="o">=</span> <span class="p">{}),</span> <span class="nb">len</span><span class="p">(</span><span class="n">loaders</span><span class="p">[</span><span class="s">"fit"</span><span class="p">].</span><span class="n">dataset</span><span class="p">),</span> <span class="n">history</span>
<span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">parameters</span><span class="p">,</span> <span class="n">config</span>
<span class="p">):</span>
<span class="k">return</span> <span class="nb">float</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">loaders</span><span class="p">[</span><span class="s">"eval"</span><span class="p">].</span><span class="n">dataset</span><span class="p">)),</span> <span class="nb">len</span><span class="p">(</span><span class="n">loaders</span><span class="p">[</span><span class="s">"eval"</span><span class="p">].</span><span class="n">dataset</span><span class="p">),</span> <span class="p">{}</span>
<span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="p">.</span><span class="n">ArgumentParser</span><span class="p">()</span>
<span class="n">parser</span><span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"--server_address"</span><span class="p">,</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default</span> <span class="o">=</span> <span class="s">"192.168.50.102"</span><span class="p">),</span> <span class="n">parser</span><span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"--server_port"</span><span class="p">,</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">int</span><span class="p">)</span>
<span class="n">parser</span><span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"--dataset"</span><span class="p">,</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default</span> <span class="o">=</span> <span class="s">"CIFAR10"</span><span class="p">),</span> <span class="n">parser</span><span class="p">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s">"--cid"</span><span class="p">,</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">int</span><span class="p">)</span>
<span class="n">args</span> <span class="o">=</span> <span class="n">parser</span><span class="p">.</span><span class="n">parse_args</span><span class="p">()</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">pandas</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s">"../datasets/{}/clients/client_{}.csv"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">,</span> <span class="n">args</span><span class="p">.</span><span class="n">cid</span><span class="p">))</span>
<span class="n">loaders</span> <span class="o">=</span> <span class="p">{</span>
<span class="s">"fit"</span><span class="p">:</span><span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">ImageDataset</span><span class="p">(</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="n">df</span><span class="p">[</span><span class="s">"phase"</span><span class="p">]</span> <span class="o">==</span> <span class="s">"fit"</span><span class="p">],</span> <span class="n">data_path</span> <span class="o">=</span> <span class="s">"../datasets/{}/train"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">),</span>
<span class="p">),</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">shuffle</span> <span class="o">=</span> <span class="bp">True</span>
<span class="p">),</span>
<span class="s">"eval"</span><span class="p">:</span><span class="n">torch</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">ImageDataset</span><span class="p">(</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="n">df</span><span class="p">[</span><span class="s">"phase"</span><span class="p">]</span> <span class="o">==</span> <span class="s">"eval"</span><span class="p">],</span> <span class="n">data_path</span> <span class="o">=</span> <span class="s">"../datasets/{}/train"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">),</span>
<span class="p">),</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">shuffle</span> <span class="o">=</span> <span class="bp">False</span>
<span class="p">),</span>
<span class="p">}</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">LeNet5</span><span class="p">(</span><span class="mi">1</span> <span class="k">if</span> <span class="s">"MNIST"</span> <span class="ow">in</span> <span class="n">args</span><span class="p">.</span><span class="n">dataset</span> <span class="k">else</span> <span class="mi">3</span><span class="p">,</span> <span class="n">num_classes</span> <span class="o">=</span> <span class="mi">10</span><span class="p">)</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">device</span><span class="p">(</span><span class="s">"cuda"</span> <span class="k">if</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s">"cpu"</span><span class="p">)</span>
<span class="n">save_ckp_path</span> <span class="o">=</span> <span class="s">"../ckps/{}/client_{}.ptl"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">dataset</span><span class="p">,</span> <span class="n">args</span><span class="p">.</span><span class="n">cid</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">exists</span><span class="p">(</span><span class="s">"/"</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">save_ckp_path</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="s">"/"</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])):</span>
<span class="n">os</span><span class="p">.</span><span class="n">makedirs</span><span class="p">(</span><span class="s">"/"</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">save_ckp_path</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="s">"/"</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span>
<span class="n">client</span> <span class="o">=</span> <span class="n">Client</span><span class="p">(</span>
<span class="n">loaders</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">device</span><span class="p">,</span>
<span class="n">save_ckp_path</span> <span class="o">=</span> <span class="n">save_ckp_path</span><span class="p">,</span> <span class="n">training_verbose</span> <span class="o">=</span> <span class="bp">True</span>
<span class="p">)</span>
<span class="n">fl</span><span class="p">.</span><span class="n">client</span><span class="p">.</span><span class="n">start_numpy_client</span><span class="p">(</span>
<span class="n">server_address</span> <span class="o">=</span> <span class="s">"{}:{}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">args</span><span class="p">.</span><span class="n">server_address</span><span class="p">,</span> <span class="n">args</span><span class="p">.</span><span class="n">server_port</span><span class="p">),</span>
<span class="n">client</span> <span class="o">=</span> <span class="n">client</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div></div>
<p>Now, everything is ready to start: on your laptop, run the server, and on each device, run a client.</p>
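<p>For example, assuming the two scripts above are saved as <code class="language-plaintext highlighter-rouge">server.py</code> and <code class="language-plaintext highlighter-rouge">client.py</code> (hypothetical file names; the port is a placeholder too), a run might look like this:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># on the laptop (server)
$ python server.py --server_port 9999
# on each device (client), with --cid ranging over 0..9
$ python client.py --server_port 9999 --cid 0
</code></pre></div></div>
<p>As you can see, I use <code class="language-plaintext highlighter-rouge">wandb</code> to log all metrics during training. This is what the metrics look like after 100 rounds:</p>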
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/federated-learning-iot/metrics.jpg" />
<figcaption>Figure 2. Training Loss and Accuracy. </figcaption>
</figure>
<p>Stay tuned for more content …</p>
<h2 id="references">References</h2>
<p style="font-size: 14px;"><a href="https://www.cs.toronto.edu/~kriz/cifar.html">[1] CIFAR10 and CIFAR100 Datasets</a><br />
<a href="https://flower.dev/">[2] Flower: A Friendly Federated Learning Framework</a><br /></p>Gathering.AI👋 Hi there. Welcome back to my page, this is part 2 of my tutorial series on deploying Federated Learning on IoT devices. In the last article, we discussed what FL is and built a network of IoT devices as well as environments for starting work. Today, I will guide you step by step to train a simple CNN model on the CIFAR10 dataset in real IoT devices by using Flower. Let’s get started.Federated Learning on IoT Devices - Part 12022-10-22T00:00:00+00:002022-10-22T00:00:00+00:00https://gather-ai.github.io/tutorials/federated-learning-iot-part-1<p>👋 Hi there. Recently, I have started working on the <a href="https://en.wikipedia.org/wiki/Federated_learning">Federated Learning</a> (FL) field, Federated Learning deployment on IoT devices in specific. In this 2-part series of tutorials, I will use a powerful framework <a href="https://flower.dev/">Flower</a> to implement a simple FL algorithm on a real network of IoT devices. Let’s get started by summarizing what FL is.</p>
<h2 id="1-background">1. Background</h2>
<h3 id="motivation">Motivation</h3>
<p>Currently, there are nearly 7 billion connected Internet of Things (IoT) devices and 3 billion smartphones around the world. These devices generate data at the edge constantly. However, due to limits in data privacy regulations and communication bandwidth, it is usually infeasible to transmit and store all training data at a central location. Coupled with the rise of Machine Learning (ML), the wealth of data collected by end devices opens up countless possibilities for meaningful research and applications.</p>
<p>From these observations, the topic of Federated Learning (FL) was introduced. FL is a distributed ML strategy that generates a global model by learning from multiple decentralized edge clients. FL enables on-device training, keeps each client’s local data private, and updates the global model based only on the clients’ local model updates.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/federated-learning-iot/flower.jpg" />
<figcaption>Figure 1. Federated Learning Illustration. </figcaption>
</figure>
<h3 id="formulation">Formulation</h3>
<p>The federated learning problem involves learning a single, global model from data stored on tens to potentially millions of remote devices. In particular, the goal is typically to minimize the following objective function:</p>
\[\min_{\theta} L(\theta)\]
<p>where</p>
\[L(\theta) := \sum_{m=1}^{M} p_mL_m(\theta)\]
<p>Here $M$ is the total number of devices, $L_m$ is the local objective function for the $m^{th}$ device, and $p_m$ specifies the relative impact of each device with $p_m \geq 0$ and $\sum_{m=1}^{M}p_m = 1$. The local objective function $L_m$ is often defined as the empirical risk over local data. The relative impact of each device $p_m$ is user-defined.</p>
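<p>For instance, in the popular <a href="https://arxiv.org/abs/1602.05629">FedAvg</a> strategy, $p_m$ is typically set to $n_m/n$, where $n_m$ is the number of training samples on the $m^{th}$ device and $n = \sum_{m=1}^{M}n_m$, so devices holding more data contribute proportionally more to the global model.</p>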
<h2 id="2-a-network-of-iot-devices">2. A Network of IoT Devices</h2>
<h3 id="network-ingredients">Network Ingredients</h3>
<p>Most existing research on FL simulates the FL setting on a single machine. This is not very realistic, because it hides major issues of real FL such as communication overhead and system heterogeneity. In this tutorial, to expose and handle these issues, I create and use a local network that consists of various types of edge devices. Specifically, I use 3 Raspberry Pi (RPi) 4 Model B, 2 NVIDIA Jetson Nano (Jetson) 2GB, and 5 NVIDIA Jetson Nano 4GB. Figure 2 shows the ingredients of the network, where my laptop is used as a remote server and connects to all devices via local Wi-Fi.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/federated-learning-iot/network.jpg" />
<figcaption>Figure 2. A Network of IoT Devices. </figcaption>
</figure>
<h3 id="setup-environments">Setup Environments</h3>
<p>Setting up and configuring environments that work smoothly on the Raspberry Pi and NVIDIA Jetson Nano is a nightmare. Therefore, I recommend using <a href="https://docs.docker.com/">Docker</a> for convenience and consistency. I have built Docker images for the RPi and Jetson <a href="https://hub.docker.com/repositories">here</a>; you can pull and run them without any additional installation. Make sure that you have booted Raspberry Pi OS (64-bit) for the RPi and JetPack 4.6.1 for the Jetson.</p>
<ul>
<li>For RPi:
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>$ docker pull lhkhiem28/rpi:1.4
$ docker run -it -d -w /root --network=host --name=rpi-container -v $(pwd):/usr/src/ lhkhiem28/rpi:1.4 /bin/bash
</code></pre></div> </div>
</li>
<li>For Jetson:
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>$ docker pull lhkhiem28/jetson:1.0
$ docker run -it -d -w /root --runtime=nvidia --network=host --name=jetson-container -v $(pwd):/usr/src/ lhkhiem28/jetson:1.0 /bin/bash
</code></pre></div> </div>
</li>
</ul>
<p>Now we are ready to start deploying FL on IoT devices. In the next part, I will implement a common FL strategy, FedAvg, on the above network using Flower.</p>
<h2 id="references">References</h2>
<p style="font-size: 14px;"><a href="https://ieeexplore.ieee.org/document/9475501">[1] A Survey on Federated Learning for Resource-Constrained IoT Devices</a><br />
<a href="https://blog.ml.cmu.edu/2019/11/12/federated-learning-challenges-methods-and-future-directions/">[2] Federated Learning: Challenges, Methods, and Future Directions</a><br /></p>Gathering.AI👋 Hi there. Recently, I have started working on the Federated Learning (FL) field, Federated Learning deployment on IoT devices in specific. In this 2-part series of tutorials, I will use a powerful framework Flower to implement a simple FL algorithm on a real network of IoT devices. Let’s get started by summarizing what FL is.Domain Generalization Tutorials - Part 52022-10-15T00:00:00+00:002022-10-15T00:00:00+00:00https://gather-ai.github.io/tutorials/domain-generalization-part-5<p>👋 Hi there. Welcome back to my page, this is part 5 of my tutorial series about the topic of Domain Generalization (DG). While all previous parts discussed DG methods that focus on the training phase, this article presents a new and unique approach that focuses on the test phase, namely <strong>test-time adjustment</strong>.</p>
<!-- You can find the source code of the whole series [here](https://github.com/lhkhiem28/DGECG).
{: .notice--info} -->
<h2 id="1-test-time-adjustment">1. Test-Time Adjustment</h2>
<p>Test-time adjustment is a novel approach to DG problems in which the trained model adjusts its own parameters at test time to correct its predictions. Since no data about the target domain is available during training in a DG setup, existing DG methods focus on how to use labeled data from multiple source domains. However, at test time, the model always has access to test data from the target domain. This data is constrained to be:</p>
<ul>
<li>unlabeled,</li>
<li>only available online (models cannot know all test cases in advance).</li>
</ul>
<p>Nevertheless, this data provides clues about the target distribution that are not available during training. It is natural to ask: how can we use the off-the-shelf unlabeled data available at test time to increase performance on the target domain?</p>
<h2 id="2-test-time-template-adjuster">2. Test-Time Template Adjuster</h2>
<h3 id="motivation">Motivation</h3>
<p>Test-time template adjuster (T3A) is a pioneer of this approach. The method is an optimization-free procedure that adjusts the linear classifier (the last layer of a deep neural network) at test time. This procedure makes the adjusted decision boundary avoid high-density regions of the target data and reduces the ambiguity (entropy) of predictions, which is known to be connected to classification error. One interesting property of T3A is that it does not alter the training phase; therefore, it can be used together with any existing DG algorithm. Moreover, it can be used with any classification model, since it only adjusts the linear classifier on top of the representations.</p>
<h3 id="method">Method</h3>
<p>Firstly, what is “template” in the name of T3A? Let’s say the linear classifier of a trained model is denoted as $g$ with the parameters $\theta_{g}$. $\theta_{g}$ has a shape of $dim_{z}\times C$, where $dim_{z}$ is the dimension of output from feature extractor $f$ and $C$ is the total number of classes. The template of representations for the class $k$ is defined as:</p>
\[\omega^k = \theta_{g}[:, k]\]
<p>During test time, the model generates its logits by measuring the distance (dot product) between its templates and the representations $z$ of the input data $x$, then the prediction $\widehat{y}$ is made by final operations, e.g., softmax function for multi-class classification:</p>
\[logit^k = z\omega^k\]
<p>Since these templates were learned on the source domains, there is no guarantee that they will be good templates in the target domain.</p>
<p>Next, how does T3A adjust the model templates to make better predictions on the target domain? Assume we have (batch of) test data $x$ at time $t$, T3A introduces a <em>support set</em> $\mathbb{S}_t^k$ for each class $k$:</p>
\[\begin{align}
\mathbb{S}_t^k &= \begin{cases}
\mathbb{S}_{t-1}^k \cup \{ \frac{f(x)}{\left \| f(x) \right \|} \} & \text{if $\widehat{y}=k$} \\ \mathbb{S}_{t-1}^k & \text{else}
\end{cases}
\end{align}\]
<p>where \(\left \| \cdot \right \|\) represents the L2 norm of a vector and \(\mathbb{S}_0^k = \{ \frac{\omega^k}{\left \| \omega^k \right \|} \}\). If the input data contains multiple samples at the same time (e.g., a batch of data), the above procedure is repeated for each sample in the batch.</p>
<p>Then, T3A uses the centroids of these support sets as adjusted templates to make its new predictions:</p>
\[c^k = \frac{1}{\left | \mathbb{S}^k \right |}\sum_{s\in \mathbb{S}^k}s\]
<p>and</p>
\[logit^k = zc^k\]
<p>then the prediction $\widehat{y}$ is made by final operations, e.g., softmax function for multi-class classification, sigmoid and thresholding for multi-label classification.</p>
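<p>To make this concrete, below is a minimal PyTorch sketch of one T3A step (my own illustration, not the authors’ official implementation). It assumes <code class="language-plaintext highlighter-rouge">supports</code> is a list of $C$ tensors initialized from the normalized templates $\omega^k$, and <code class="language-plaintext highlighter-rouge">z</code> holds L2-normalized representations from the frozen backbone.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>"""
A minimal T3A sketch (illustration only).
"""
import torch

@torch.no_grad()
def t3a_step(z, supports):
    # z: [B, dim_z] L2-normalized representations f(x)/||f(x)||
    # supports: list of C tensors; supports[k] stacks the support set S^k
    # the adjusted templates are the centroids of the support sets
    centroids = torch.stack([s.mean(dim = 0) for s in supports])  # [C, dim_z]
    logits = z @ centroids.t()  # dot product between representations and templates
    preds = logits.argmax(dim = 1)
    # grow each support set with the pseudo-labeled normalized representations
    for k in range(len(supports)):
        if (preds == k).any():
            supports[k] = torch.cat([supports[k], z[preds == k]], dim = 0)
    return logits, supports
</code></pre></div></div>
<p>Note that the full method additionally filters each support set to keep only its lowest-entropy members, which stabilizes the centroids.</p>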
<head><style>hr.solid {border-top: 1px solid #bbb;}</style></head>
<body><hr class="solid" /></body>
<p>This is the final part of my tutorial series on Domain Generalization (DG). Actually, there is another interesting approach to DG which is based on <a href="https://en.wikipedia.org/wiki/Meta_learning_(computer_science)">Meta-Learning</a>. I might come back to this approach later.</p>
<h2 id="references">References</h2>
<p style="font-size: 14px;"><a href="https://arxiv.org/abs/2006.10726">[1] Tent: Fully Test-time Adaptation by Entropy Minimization</a><br />
<a href="https://proceedings.neurips.cc/paper/2021/hash/1415fe9fea0fa1e45dddcff5682239a0-Abstract.html">[2] Test-Time Classifier Adjustment Module for Model-Agnostic Domain Generalization</a><br /></p>Gathering.AI👋 Hi there. Welcome back to my page, this is part 5 of my tutorial series about the topic of Domain Generalization (DG). While all previous parts discussed DG methods that focus on the training phase, this article presents a new and unique approach that focuses on the test phase, namely test-time adjustment.Domain Generalization Tutorials - Part 42022-10-08T00:00:00+00:002022-10-08T00:00:00+00:00https://gather-ai.github.io/tutorials/domain-generalization-part-4<p>👋 Hi there. Welcome back to my page, this is part 4 of my tutorial series about the topic of Domain Generalization (DG). This article will cover the approach of <strong>domain alignment</strong>, to which most existing DG methods belong. In addition, we also cover an improvement upon this approach.</p>
<!-- You can find the source code of the whole series [here](https://github.com/lhkhiem28/DGECG).
{: .notice--info} -->
<h2 id="1-domain-alignment">1. Domain Alignment</h2>
<p>The central idea of domain alignment is to minimize the difference among source domains for learning <em>domain-invariant representations</em>. The motivation is straightforward: features that are invariant to the source domains should also generalize well on any unseen target domain. Traditionally, the difference among source domains is modeled by <a href="https://arxiv.org/abs/1612.01939">Feature Correlation</a> or <a href="https://jmlr.csail.mit.edu/papers/v13/gretton12a.html">Maximum Mean Discrepancy</a>, and these quantities are then minimized to learn domain-invariant representations. Here, however, let’s explore simpler and more effective domain alignment methods.</p>
<h2 id="2-domain-adversarial-training">2. Domain-Adversarial Training</h2>
<h3 id="motivation">Motivation</h3>
<p>Don’t be put off by the word “adversarial”: this method is simple to understand, especially if you have read about multi-task learning in <a href="https://gather-ai.github.io/tutorials/domain-generalization-part-2/">part 2</a> of the series, but even if you haven’t, it’s still simple. Domain-adversarial training (DAT) perfectly represents the spirit of the domain alignment approach: learning representations from which one cannot tell which source domain an instance came from.</p>
<p>By leveraging a multi-task learning setting, DAT combines discriminativeness and domain-invariance into the same representations. To this end, a subtle trick is introduced along with the main method.</p>
<h3 id="method">Method</h3>
<p>Specifically, alongside the main task of cardiac abnormality classification, DAT performs a subtask of domain identification and uses a gradient reversal layer to learn the representations in an adversarial manner. Figure 1 illustrates the architecture of the model, and Snippet 1 describes the auxiliary module that performs DAT.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/domain-adversarial-training.jpg" />
<figcaption>Figure 1. Domain-adversarial training architecture. </figcaption>
</figure>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 1: DAT module.
"""</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="k">class</span> <span class="nc">SEResNet34</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">auxiliary</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="n">GradientReversal</span><span class="p">(),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.2</span><span class="p">),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="n">num_domains</span><span class="p">),</span>
<span class="p">)</span>
<span class="p">...</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span>
<span class="p">...</span>
<span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">classifier</span><span class="p">(</span><span class="n">feature</span><span class="p">),</span> <span class="bp">self</span><span class="p">.</span><span class="n">auxiliary</span><span class="p">(</span><span class="n">feature</span><span class="p">)</span>
</code></pre></div></div>
<p>The model is optimized with a combined loss similar to multi-task learning. Snippet 2 describes the optimization process.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 2: Optimization process.
"""</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="p">...</span>
<span class="n">logits</span><span class="p">,</span> <span class="n">sub_logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">ecgs</span><span class="p">)</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">sub_loss</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">binary_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">),</span> <span class="n">F</span><span class="p">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">sub_logits</span><span class="p">,</span> <span class="n">domains</span><span class="p">)</span>
<span class="p">(</span><span class="n">loss</span> <span class="o">+</span> <span class="n">auxiliary_lambda</span><span class="o">*</span><span class="n">sub_loss</span><span class="p">).</span><span class="n">backward</span><span class="p">()</span>
<span class="p">...</span>
</code></pre></div></div>
<p>Intuitively, the gradient reversal layer acts as the identity in the forward pass and simply flips the sign of the gradient flowing through it during backpropagation. Look at the position of this layer: it is placed right before the domain classifier $g_{d}$, which means that during training, $g_{d}$ is updated with $\frac{\partial L_{sub}}{\partial \theta_{g_d}}$ while the backbone $f$ is updated with $-\frac{\partial L_{sub}}{\partial \theta_{f}}$. In this way, the domain classifier learns how to use the representations to identify the source domain of instances, but passes the reversed signal back to the backbone, forcing $f$ to generate domain-invariant representations.</p>
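<p>For completeness, the <code class="language-plaintext highlighter-rouge">GradientReversal</code> layer from Snippet 1 can be implemented in a few lines. Below is one possible sketch (some implementations additionally scale the reversed gradient by a weight $\lambda$):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>"""
A possible implementation of the gradient reversal layer.
"""
import torch
import torch.nn as nn

class GradientReversalFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # identity in the forward pass
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # flip the sign of the gradient in the backward pass
        return -grad_output

class GradientReversal(nn.Module):
    def forward(self, x):
        return GradientReversalFunction.apply(x)
</code></pre></div></div>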
<h2 id="3-instance-batch-normalization-network">3. Instance-Batch Normalization Network</h2>
<h3 id="motivation-1">Motivation</h3>
<p>Nowadays, normalization layers are an important part of any neural network. There are many types of normalization techniques and each of them has its own characteristics and advantages, perhaps you have seen Figure 2 somewhere. We will talk about batch normalization (BN) and instance normalization (IN) here because of their effects on DG.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/normalization-techniques.jpg" />
<figcaption>Figure 2. Different normalization techniques. </figcaption>
</figure>
<p>Although BN generally works well in a variety of tasks, it consistently degrades performance when it is trained in the presence of a large domain divergence. This is because the batch statistics overfit the particular training domains, resulting in poor generalization performance in unseen target domains. Meanwhile, IN does not depend on batch statistics. This property allows the network to learn feature representations that overfit less to a particular domain. The downside of IN, however, is that it makes the features less discriminative with respect to instance categories, whereas BN preserves this discriminability. Instance-Batch normalization (I-BN) is a mixture of BN and IN, introduced to reap IN’s benefit of learning domain-invariant representations while maintaining BN’s ability to capture discriminative representations.</p>
<h3 id="method-1">Method</h3>
<p>Snippet 3 is a simple implementation of a one-dimensional I-BN layer: it applies BN to one half of the channels and IN to the other half. It is straightforward to extend the implementation to higher dimensions.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 3: I-BN layer.
"""</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="k">class</span> <span class="nc">Instance_BatchNorm1d</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_features</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">Instance_BatchNorm1d</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="p">.</span><span class="n">half_num_features</span> <span class="o">=</span> <span class="n">num_features</span><span class="o">//</span><span class="mi">2</span>
<span class="bp">self</span><span class="p">.</span><span class="n">BN</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">IN</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">BatchNorm1d</span><span class="p">(</span><span class="n">num_features</span> <span class="o">-</span> <span class="bp">self</span><span class="p">.</span><span class="n">half_num_features</span><span class="p">),</span> <span class="n">nn</span><span class="p">.</span><span class="n">InstanceNorm1d</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">half_num_features</span><span class="p">,</span> <span class="n">affine</span> <span class="o">=</span> <span class="bp">True</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">):</span>
<span class="n">half_input</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">half_num_features</span><span class="p">,</span> <span class="n">dim</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">half_BN</span><span class="p">,</span> <span class="n">half_IN</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">BN</span><span class="p">(</span><span class="n">half_input</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">contiguous</span><span class="p">()),</span> <span class="bp">self</span><span class="p">.</span><span class="n">IN</span><span class="p">(</span><span class="n">half_input</span><span class="p">[</span><span class="mi">1</span><span class="p">].</span><span class="n">contiguous</span><span class="p">())</span>
<span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">((</span><span class="n">half_BN</span><span class="p">,</span> <span class="n">half_IN</span><span class="p">),</span> <span class="n">dim</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>
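<p>Note that this sketch implicitly assumes <code class="language-plaintext highlighter-rouge">num_features</code> is even; for an odd number of channels, the two halves should be given explicit sizes, e.g., by passing a list of sizes to <code class="language-plaintext highlighter-rouge">torch.split</code>.</p>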
<p>But where should I-BN layers be placed in a specific network, for example a ResNet-like model? Another observation showed that, for BN-based CNNs, the feature divergence caused by appearance variance (domain shift) mainly lies in the shallow half of the CNN, while the feature discrimination for categories is high in deep layers but also exists in shallow layers. Therefore, an original ResNet can be modified as follows to become an I-BN ResNet:</p>
<ul>
<li>Only use I-BN layers in the first three residual blocks and leave the fourth block as normal (similar to MixStyle in the <a href="https://gather-ai.github.io/tutorials/domain-generalization-part-3/">previous article</a>)</li>
<li>For each selected block, only replace the BN layer after the first convolution layer in the main path with an I-BN layer</li>
</ul>
<p>Snippet 4 illustrates this setting.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 4: I-BN ResNet setting.
"""</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="k">class</span> <span class="nc">SEResNet34</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span> <span class="o">=</span> <span class="n">I_NBSEBlock</span><span class="p">()</span>
<span class="p">...</span>
<span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">stem</span> <span class="o">=</span> <span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">stage_0</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">stage_1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">stage_2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">True</span><span class="p">),</span>
<span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">stage_3</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">False</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">False</span><span class="p">),</span>
<span class="bp">self</span><span class="p">.</span><span class="n">block</span><span class="p">(</span><span class="n">i_bn</span> <span class="o">=</span> <span class="bp">False</span><span class="p">),</span>
<span class="p">)</span>
<span class="p">...</span>
</code></pre></div></div>
<h2 id="4-domain-specific-i-bn-network">4. Domain-Specific I-BN Network</h2>
<h3 id="motivation-2">Motivation</h3>
<p>Domain alignment methods generally share a common limitation, which will be discussed and addressed here. Look back at an illustration of DG from <a href="https://gather-ai.github.io/tutorials/domain-generalization-part-1/">part 1</a>, where a classifier trained on <em>sketch</em>, <em>cartoon</em>, and <em>art painting</em> images encounters instances from a novel domain, <em>photo</em>, at test time.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/DG-DA.jpg" />
<figcaption>Figure 3. Examples from the PACS dataset for DG. Adapted from [1]. </figcaption>
</figure>
<p>It is reasonable to note that leveraging the relative similarity of the <em>photo</em> instances to instances from <em>art painting</em> might result in better predictions than a setting where the model relies solely on invariant characteristics across domains. Both methods covered above try to learn domain-invariant representations while ignoring domain-specific features, i.e., features that are unique to each individual domain.</p>
<p>Extending the above I-BN Net, the domain-specific I-BN Net (DS I-BN Net) was developed, which aims to capture both domain-invariant and domain-specific features from multi-source domain data.</p>
<h3 id="method-2">Method</h3>
<p>In particular, an original ResNet can be modified to become a DS I-BN ResNet in the following two steps:</p>
<ul>
<li>Turn all BN layers in the model into domain-specific BN (<a href="https://arxiv.org/abs/1906.03950">DSBN</a>) modules</li>
<li>Replace BN layers with I-BN layers at the same positions as I-BN ResNet</li>
</ul>
<p>What is DSBN? DSBN is a module that consists of $M$ BN layers, using the parameters of each BN layer to capture the domain-specific features of one of the $M$ source domains. Specifically, during training, instances $\mathbf{X}^{m}$ from domain $m$ only go through the $m^{th}$ BN layer in the DSBN module. Figure 4 illustrates the module, and Snippet 5 is a one-dimensional implementation.</p>
<figure class="align-center" style="width: 400px">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/domain-specific-batch-normalization.jpg" />
<figcaption>Figure 4. Domain-specific batch normalization module architecture. </figcaption>
</figure>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 5: DSBN module.
"""</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="k">class</span> <span class="nc">DomainSpecificBatchNorm1d</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">num_domains</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">DomainSpecificBatchNorm1d</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="p">.</span><span class="n">num_domains</span> <span class="o">=</span> <span class="n">num_domains</span>
<span class="bp">self</span><span class="p">.</span><span class="n">BNs</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">ModuleList</span><span class="p">(</span>
<span class="p">[</span><span class="n">nn</span><span class="p">.</span><span class="n">BatchNorm1d</span><span class="p">(</span><span class="n">num_features</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_domains</span><span class="p">)]</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">domains</span><span class="p">,</span> <span class="n">is_training</span> <span class="o">=</span> <span class="bp">True</span><span class="p">,</span> <span class="n">running_domain</span> <span class="o">=</span> <span class="bp">None</span><span class="p">):</span>
<span class="n">domain_uniques</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">unique</span><span class="p">(</span><span class="n">domains</span><span class="p">)</span>
<span class="k">if</span> <span class="n">is_training</span><span class="p">:</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="p">.</span><span class="n">BNs</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="nb">input</span><span class="p">[</span><span class="n">domains</span> <span class="o">==</span> <span class="n">domain_uniques</span><span class="p">[</span><span class="n">i</span><span class="p">]])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">domain_uniques</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
<span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">concat</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">BNs</span><span class="p">[</span><span class="n">running_domain</span><span class="p">](</span><span class="n">output</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
</code></pre></div></div>
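<p>To make the first step of the recipe above concrete, here is a minimal sketch that recursively swaps every BN layer of a model for the <code class="language-plaintext highlighter-rouge">DomainSpecificBatchNorm1d</code> module from Snippet 5. The function name is illustrative, and a full conversion would also copy any pretrained BN weights and thread the domain labels through the model’s forward pass.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch.nn as nn

def convert_bn_to_dsbn(module, num_domains):
    # Recursively replace every BatchNorm1d with a DSBN module
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm1d):
            setattr(module, name, DomainSpecificBatchNorm1d(child.num_features, num_domains))
        else:
            convert_bn_to_dsbn(child, num_domains)
</code></pre></div></div>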
<p>At inference time, a test instance is fed through all $M$ domain-specific “sub-networks” to obtain $M$ logits. The final logit is the average of these $M$ logits, from which the prediction is made.</p>
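<p>As a rough sketch of this ensemble inference, assuming the model exposes the <code class="language-plaintext highlighter-rouge">is_training</code> and <code class="language-plaintext highlighter-rouge">running_domain</code> arguments of Snippet 5 (the function name and threshold below are illustrative):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

@torch.no_grad()
def predict(model, ecgs, num_domains, threshold=0.5):
    # One forward pass per domain-specific "sub-network"
    logits = torch.stack(
        [model(ecgs, is_training=False, running_domain=m) for m in range(num_domains)]
    )
    final_logit = logits.mean(dim=0)  # average over the M logits
    # Multi-label prediction: independent sigmoid per class
    return (torch.sigmoid(final_logit) > threshold).float()
</code></pre></div></div>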
<h2 id="5-results">5. Results</h2>
<p>The table below shows the performance of the methods presented in this article (DAT, I-BN, and DS I-BN), together with the results from the previous parts:</p>
<table>
<thead>
<tr>
<th style="text-align: left"> </th>
<th style="text-align: right">Chapman</th>
<th style="text-align: right">CPSC</th>
<th style="text-align: right">CPSC-Extra</th>
<th style="text-align: right">G12EC</th>
<th style="text-align: right">Ningbo</th>
<th style="text-align: right">PTB-XL</th>
<th style="text-align: right">Avg</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align: left">Baseline </td>
<td style="text-align: right"> 0.4290</td>
<td style="text-align: right"> 0.1643</td>
<td style="text-align: right"> 0.2067</td>
<td style="text-align: right"> 0.3809</td>
<td style="text-align: right"> 0.3987</td>
<td style="text-align: right"> 0.3626</td>
<td style="text-align: right"> 0.3237</td>
</tr>
<tr>
<td style="text-align: left">AgeReg</td>
<td style="text-align: right"> 0.4222</td>
<td style="text-align: right"> 0.1715</td>
<td style="text-align: right"> 0.2136</td>
<td style="text-align: right"> 0.3923</td>
<td style="text-align: right"> 0.4024</td>
<td style="text-align: right"> 0.4021</td>
<td style="text-align: right"> 0.3340</td>
</tr>
<tr>
<td style="text-align: left">SWA</td>
<td style="text-align: right"> 0.4271</td>
<td style="text-align: right"> 0.1759</td>
<td style="text-align: right"> 0.2052</td>
<td style="text-align: right"> 0.3969</td>
<td style="text-align: right"> 0.4313</td>
<td style="text-align: right"> 0.4203</td>
<td style="text-align: right"> 0.3428</td>
</tr>
<tr>
<td style="text-align: left">Mixup</td>
<td style="text-align: right"> 0.4225</td>
<td style="text-align: right"> 0.1759</td>
<td style="text-align: right"> 0.2127</td>
<td style="text-align: right"> 0.3901</td>
<td style="text-align: right"> 0.4025</td>
<td style="text-align: right"> 0.3934</td>
<td style="text-align: right"> 0.3329</td>
</tr>
<tr>
<td style="text-align: left">MixStyle</td>
<td style="text-align: right"> 0.4253</td>
<td style="text-align: right"> 0.1681</td>
<td style="text-align: right"> 0.2027</td>
<td style="text-align: right"> 0.3927</td>
<td style="text-align: right"> 0.4117</td>
<td style="text-align: right"> 0.3853</td>
<td style="text-align: right"> 0.3310</td>
</tr>
<tr>
<td style="text-align: left">DAT</td>
<td style="text-align: right"> 0.4282</td>
<td style="text-align: right"> 0.1712</td>
<td style="text-align: right"> 0.1966</td>
<td style="text-align: right"> 0.3956</td>
<td style="text-align: right"> 0.4114</td>
<td style="text-align: right"> 0.3878</td>
<td style="text-align: right"> <strong>0.3318</strong></td>
</tr>
<tr>
<td style="text-align: left">I-BN</td>
<td style="text-align: right"> 0.4252</td>
<td style="text-align: right"> 0.1748</td>
<td style="text-align: right"> 0.2045</td>
<td style="text-align: right"> 0.3817</td>
<td style="text-align: right"> 0.4193</td>
<td style="text-align: right"> 0.4161</td>
<td style="text-align: right"> <strong>0.3369</strong></td>
</tr>
<tr>
<td style="text-align: left">DS I-BN</td>
<td style="text-align: right"> 0.4484</td>
<td style="text-align: right"> 0.1805</td>
<td style="text-align: right"> 0.2191</td>
<td style="text-align: right"> 0.4318</td>
<td style="text-align: right"> 0.3916</td>
<td style="text-align: right"> 0.4242</td>
<td style="text-align: right"> <strong>0.3493</strong></td>
</tr>
</tbody>
</table>
<p>To be continued …</p>
<h2 id="references">References</h2>
<p style="font-size: 14px;"><a href="https://arxiv.org/abs/2103.02503">[1] Domain Generalization: A Survey</a><br />
<a href="https://arxiv.org/abs/1505.07818">[2] Domain-Adversarial Training of Neural Networks</a><br />
<a href="https://arxiv.org/abs/1807.09441">[3] Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net</a><br />
<a href="https://arxiv.org/abs/1907.04275">[4] Learning to Optimize Domain Specific Normalization for Domain Generalization</a><br />
<a href="https://arxiv.org/abs/2008.12839">[5] Learning to Balance Specificity and Invariance for In and Out of Domain Generalization</a><br /></p>Gathering.AI👋 Hi there. Welcome back to my page, this is part 4 of my tutorial series about the topic of Domain Generalization (DG). This article will cover the approach of domain alignment, to which most existing DG methods belong. In addition, we also cover an improvement upon this approach.Domain Generalization Tutorials - Part 32022-10-01T00:00:00+00:002022-10-01T00:00:00+00:00https://gather-ai.github.io/tutorials/domain-generalization-part-3<p>👋 Hi there. Welcome back to my page, this is part 3 of my tutorial series about the topic of Domain Generalization (DG). From this article, we will explore domain-aware approaches which take the problem domain shift into account. Today, I introduce the first family of methods which is <strong>inter-domain data augmentation</strong>.</p>
<!-- You can find the source code of the whole series [here](https://github.com/lhkhiem28/DGECG).
{: .notice--info} -->
<h2 id="1-inter-domain-data-augmentation">1. Inter-domain Data Augmentation</h2>
<p><strong>Mixing data augmentation</strong> is an emerging type of augmentation method that has shown superior performance in recent years. Methods of this type take a <a href="https://en.wikipedia.org/wiki/Convex_combination">convex combination</a> (mix) of two data instances at the input or feature level, hence generating a new instance for training ML models. Unlike conventional data augmentation such as crop, scale, or cut-out, which preserves the context of the original instance, mixing augmentation creates instances with new contexts, in other words, new domains, which makes it especially suitable for the DG problem. Because these methods fit naturally into mini-batch training, we have two ways to select the data instances to mix: without domain labels (random shuffle mixing) and with domain labels (inter-domain mixing). Figure 1 illustrates these selection strategies, and a small sketch of both follows the figure. Let’s dive into Mixup and MixStyle, the two most popular and effective augmentation methods for DG.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/mixing-strategies.jpg" />
<figcaption>Figure 1. Illustration of mixing strategies. Adapted from [2]. </figcaption>
</figure>
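<p>As a minimal sketch of the two selection strategies, assuming a tensor <code class="language-plaintext highlighter-rouge">domains</code> of per-instance domain ids (an illustration, not code from this series):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

def random_shuffle(batch_size):
    # Pair each instance with a random one, ignoring domain labels
    return torch.randperm(batch_size)

def inter_domain_shuffle(domains):
    # Pair each instance with a random one from a *different* domain
    indices = torch.arange(domains.size(0))
    partners = indices.clone()
    for i in range(domains.size(0)):
        candidates = indices[domains != domains[i]]
        if len(candidates) > 0:
            partners[i] = candidates[torch.randint(len(candidates), (1,))]
    return partners
</code></pre></div></div>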
<h2 id="2-mixup">2. Mixup</h2>
<p>As mentioned above, Mixup fits naturally into mini-batch training: at each training iteration, we select two instances in a mini-batch following a given strategy (random shuffle or inter-domain) and then mix them at the input level through a convex combination to generate a new instance:</p>
\[x_{mix} = \lambda x + (1-\lambda ) x_{shuffled}\]
<p>where $\lambda$ is drawn from a <a href="https://en.wikipedia.org/wiki/Beta_distribution">Beta distribution</a>, $\lambda \sim Beta(\alpha , \alpha )$, where $\alpha \in (0, \infty )$ is a hyper-parameter.</p>
<p>We also have to create the label for the generated instance by mixing labels of original instances in the same way:</p>
\[y_{mix} = \lambda y + (1-\lambda ) y_{shuffled}\]
<p>This combination of the original labels can yield a non-integer label for the generated instance, which does not fit a classification problem that requires categorical labels. Therefore, we use a trick: mixing losses instead of mixing labels:</p>
\[loss = \lambda loss(logit_{mix}, y) + (1-\lambda ) loss(logit_{mix}, y_{shuffled})\]
<p>where $logit_{mix}$ is the model output for $x_{mix}$.</p>
<p>Snippet 1 describes how to integrate Mixup into the training pipeline.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 1: Mixup integration.
"""</span>
<span class="kn">import</span> <span class="nn">random</span>
<span class="kn">import</span> <span class="nn">pandas</span><span class="p">,</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="p">...</span>
<span class="k">if</span> <span class="n">random</span><span class="p">.</span><span class="n">random</span><span class="p">()</span> <span class="o"><</span> <span class="mf">0.5</span><span class="p">:</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">ecgs</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">binary_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">shuffled_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randperm</span><span class="p">(</span><span class="n">ecgs</span><span class="p">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">mixup_lambda</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">beta</span><span class="p">(</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">)</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">mixup_lambda</span><span class="o">*</span><span class="n">ecgs</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">mixup_lambda</span><span class="p">)</span><span class="o">*</span><span class="n">ecgs</span><span class="p">[</span><span class="n">permuted_indices</span><span class="p">])</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">mixup_lambda</span><span class="o">*</span><span class="n">F</span><span class="p">.</span><span class="n">binary_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">mixup_lambda</span><span class="p">)</span><span class="o">*</span><span class="n">F</span><span class="p">.</span><span class="n">binary_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">[</span><span class="n">permuted_indices</span><span class="p">])</span>
<span class="p">...</span>
</code></pre></div></div>
<h2 id="3-mixstyle">3. MixStyle</h2>
<p>Unlike Mixup, which creates new instances at the input level, MixStyle is a recent method that generates new instances in the feature space by mixing their “styles”. The style of an instance is represented by its feature statistics, namely the mean and standard deviation across the spatial dimensions of the feature space. At each iteration, we select two instances in a mini-batch following a given strategy (random shuffle or inter-domain) and then mix their styles in a similar way to Mixup:</p>
\[\mu _{mix} = \lambda \mu (x) + (1-\lambda ) \mu (x_{shuffled})\]
\[\sigma _{mix} = \lambda \sigma (x) + (1-\lambda ) \sigma (x_{shuffled})\]
<p>where $\mu$ and $\sigma$ are the mean and standard deviation operations, respectively, and $\lambda \sim Beta(\alpha , \alpha )$, where $\alpha \in (0, \infty )$ is a hyper-parameter.</p>
<p>Finally, the mixed feature statistics are applied to the style-normalized $x$:</p>
\[MixStyle(x) = \sigma _{mix}\frac{x - \mu (x)}{\sigma (x)} + \mu _{mix}\]
<p>If you look at the above formula carefully, you can see that MixStyle does not actually create a new instance; it mixes the style of one instance into another to make it “new”. Therefore, the “new” instance keeps the original label $y$ of $x$.</p>
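<p>For reference, here is a minimal sketch of a MixStyle module for one-dimensional features of shape (batch, channels, length), following the formulas above with random shuffle mixing; the argument names mirror those used in Snippet 2 below, and the detaching of the style statistics follows [2].</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
import torch.nn as nn

class MixStyle(nn.Module):
    def __init__(self, alpha=0.1, p=0.5, eps=1e-6):
        super().__init__()
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.p, self.eps = p, eps

    def forward(self, x, activate=True):
        # Only perturb styles during training, with probability p
        if not (self.training and activate) or torch.rand(1).item() > self.p:
            return x
        mu = x.mean(dim=2, keepdim=True)                        # style: channel-wise mean
        sigma = (x.var(dim=2, keepdim=True) + self.eps).sqrt()  # style: channel-wise std
        mu, sigma = mu.detach(), sigma.detach()                 # block gradients through styles
        shuffled = torch.randperm(x.size(0))
        lam = self.beta.sample()
        mu_mix = lam*mu + (1 - lam)*mu[shuffled]
        sigma_mix = lam*sigma + (1 - lam)*sigma[shuffled]
        # Re-style the normalized input with the mixed statistics
        return sigma_mix*(x - mu)/sigma + mu_mix
</code></pre></div></div>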
<p>Given such a module, MixStyle is as easy to integrate as Mixup, but where should it be applied? Experiments showed that applying MixStyle after each of the first three residual blocks in a ResNet-like model gives the best results on our problem. Snippet 2 illustrates this setting.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 2: MixStyle setting.
"""</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="k">class</span> <span class="nc">SEResNet34</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">augment</span> <span class="o">=</span> <span class="n">MixStyle</span><span class="p">(</span><span class="n">alpha</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">p</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">)</span>
<span class="p">...</span>
<span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">stem</span> <span class="o">=</span> <span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">stage_0</span> <span class="o">=</span> <span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">stage_1</span> <span class="o">=</span> <span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">stage_2</span> <span class="o">=</span> <span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">stage_3</span> <span class="o">=</span> <span class="p">...</span>
<span class="p">...</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">augment</span> <span class="o">=</span> <span class="bp">False</span><span class="p">):</span>
<span class="p">...</span>
<span class="n">feature</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">stem</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
<span class="n">feature</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">augment</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">stage_0</span><span class="p">(</span><span class="n">inputs</span><span class="p">),</span> <span class="n">activate</span> <span class="o">=</span> <span class="n">augment</span><span class="p">)</span>
<span class="n">feature</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">augment</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">stage_1</span><span class="p">(</span><span class="n">inputs</span><span class="p">),</span> <span class="n">activate</span> <span class="o">=</span> <span class="n">augment</span><span class="p">)</span>
<span class="n">feature</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">augment</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">stage_2</span><span class="p">(</span><span class="n">inputs</span><span class="p">),</span> <span class="n">activate</span> <span class="o">=</span> <span class="n">augment</span><span class="p">)</span>
<span class="n">feature</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">stage_3</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
<span class="p">...</span>
</code></pre></div></div>
<h2 id="4-results">4. Results</h2>
<p>The table below shows the performance of the two methods presented in this article (Mixup and MixStyle), together with the results from the previous parts:</p>
<table>
<thead>
<tr>
<th style="text-align: left"> </th>
<th style="text-align: right">Chapman</th>
<th style="text-align: right">CPSC</th>
<th style="text-align: right">CPSC-Extra</th>
<th style="text-align: right">G12EC</th>
<th style="text-align: right">Ningbo</th>
<th style="text-align: right">PTB-XL</th>
<th style="text-align: right">Avg</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align: left">Baseline </td>
<td style="text-align: right"> 0.4290</td>
<td style="text-align: right"> 0.1643</td>
<td style="text-align: right"> 0.2067</td>
<td style="text-align: right"> 0.3809</td>
<td style="text-align: right"> 0.3987</td>
<td style="text-align: right"> 0.3626</td>
<td style="text-align: right"> 0.3237</td>
</tr>
<tr>
<td style="text-align: left">AgeReg</td>
<td style="text-align: right"> 0.4222</td>
<td style="text-align: right"> 0.1715</td>
<td style="text-align: right"> 0.2136</td>
<td style="text-align: right"> 0.3923</td>
<td style="text-align: right"> 0.4024</td>
<td style="text-align: right"> 0.4021</td>
<td style="text-align: right"> 0.3340</td>
</tr>
<tr>
<td style="text-align: left">SWA</td>
<td style="text-align: right"> 0.4271</td>
<td style="text-align: right"> 0.1759</td>
<td style="text-align: right"> 0.2052</td>
<td style="text-align: right"> 0.3969</td>
<td style="text-align: right"> 0.4313</td>
<td style="text-align: right"> 0.4203</td>
<td style="text-align: right"> 0.3428</td>
</tr>
<tr>
<td style="text-align: left">Mixup</td>
<td style="text-align: right"> 0.4225</td>
<td style="text-align: right"> 0.1759</td>
<td style="text-align: right"> 0.2127</td>
<td style="text-align: right"> 0.3901</td>
<td style="text-align: right"> 0.4025</td>
<td style="text-align: right"> 0.3934</td>
<td style="text-align: right"> <strong>0.3329</strong></td>
</tr>
<tr>
<td style="text-align: left">MixStyle</td>
<td style="text-align: right"> 0.4253</td>
<td style="text-align: right"> 0.1681</td>
<td style="text-align: right"> 0.2027</td>
<td style="text-align: right"> 0.3927</td>
<td style="text-align: right"> 0.4117</td>
<td style="text-align: right"> 0.3853</td>
<td style="text-align: right"> <strong>0.3310</strong></td>
</tr>
</tbody>
</table>
<p>To be continued …</p>
<h2 id="references">References</h2>
<p style="font-size: 14px;"><a href="https://arxiv.org/abs/1710.09412">[1] Mixup: Beyond Empirical Risk Minimization</a><br />
<a href="https://arxiv.org/abs/2104.02008">[2] Domain Generalization with MixStyle</a><br /></p>Gathering.AI👋 Hi there. Welcome back to my page, this is part 3 of my tutorial series about the topic of Domain Generalization (DG). From this article, we will explore domain-aware approaches which take the problem domain shift into account. Today, I introduce the first family of methods which is inter-domain data augmentation.Domain Generalization Tutorials - Part 22022-09-24T00:00:00+00:002022-09-24T00:00:00+00:00https://gather-ai.github.io/tutorials/domain-generalization-part-2<p>👋 Hi there. Welcome back to my page, this is part 2 of my tutorial series about the topic of Domain Generalization (DG). In this article, I will introduce the first approach to the DG problem, which I call <strong>conventional generalization</strong>.</p>
<!-- You can find the source code of the whole series [here](https://github.com/lhkhiem28/DGECG).
{: .notice--info} -->
<h2 id="1-conventional-generalization">1. Conventional Generalization</h2>
<p>Conventional generalization methods such as data augmentation or weight decay aim to make ML models less overfit on training data; after training, these models are therefore assumed to generalize well on testing data regardless of its domain. This is a great starting point for approaching the DG problem. Beyond the popular data augmentation and weight decay, I will present two more advanced and effective methods, <em>multi-task learning</em> and <em>flat minima seeking</em>.</p>
<h2 id="2-multi-task-learning">2. Multi-task Learning</h2>
<h3 id="motivation">Motivation</h3>
<p>The goal of multi-task learning (with deep neural networks) is to jointly learn one or more sub-tasks besides the main task using a shared model, encouraging the model’s shared representations to be generic enough to deal with different tasks and eventually reducing overfitting on the main task. In general, the sub-tasks for multi-task learning are defined based on the specific data and problem. Joint learning is then established by minimizing a joint loss function.</p>
<p>Multi-task learning is popular in the ML literature, though we rarely notice it. For example, in Computer Vision, <a href="https://paperswithcode.com/task/object-detection">Object Detection</a> aims to localize and classify objects simultaneously. In Natural Language Processing, <a href="http://nlpprogress.com/english/intent_detection_slot_filling.html">Intent Detection and Slot Filling</a> aims to simultaneously identify the speaker’s intent from a given utterance and extract the correct argument values for the slots of that intent.</p>
<h3 id="method">Method</h3>
<p>As mentioned before, sub-tasks for multi-task learning are defined based on the specific data and problem. In our ECG-based cardiac abnormalities classification problem, I define and perform a sub-task of <em>age regression</em> (AgeReg) from ECGs, which is feasible from a medical perspective. Figure 1 illustrates the model architecture, and Snippet 1 describes the auxiliary module that performs the regression.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/multi-task-learning.jpg" />
<figcaption>Figure 1. Multi-task learning architecture. </figcaption>
</figure>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 1: Age Regression module.
"""</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="k">class</span> <span class="nc">SEResNet34</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="p">...</span>
<span class="bp">self</span><span class="p">.</span><span class="n">auxiliary</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.2</span><span class="p">),</span>
<span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="p">)</span>
<span class="p">...</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span>
<span class="p">...</span>
<span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">classifier</span><span class="p">(</span><span class="n">feature</span><span class="p">),</span> <span class="bp">self</span><span class="p">.</span><span class="n">auxiliary</span><span class="p">(</span><span class="n">feature</span><span class="p">)</span>
</code></pre></div></div>
<p>For optimization, I use binary cross-entropy loss for the main classification task and L1 loss for the regression sub-task. The sub-task loss is added to the main loss, weighted by an <code class="language-plaintext highlighter-rouge">auxiliary_lambda</code> hyperparameter set to 0.02. Snippet 2 describes the optimization process. All other settings are similar to the baseline in the <a href="https://gather-ai.github.io/tutorials/domain-generalization-part-1/">previous article</a>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 2: Optimization process.
"""</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="p">...</span>
<span class="n">logits</span><span class="p">,</span> <span class="n">sub_logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">ecgs</span><span class="p">)</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">sub_loss</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">binary_cross_entropy_with_logits</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">),</span> <span class="n">F</span><span class="p">.</span><span class="n">l1_loss</span><span class="p">(</span><span class="n">sub_logits</span><span class="p">,</span> <span class="n">ages</span><span class="p">)</span>
<span class="p">(</span><span class="n">loss</span> <span class="o">+</span> <span class="n">auxiliary_lambda</span><span class="o">*</span><span class="n">sub_loss</span><span class="p">).</span><span class="n">backward</span><span class="p">()</span>
<span class="p">...</span>
</code></pre></div></div>
<h2 id="3-flat-minima-seeking">3. Flat Minima Seeking</h2>
<h3 id="motivation-1">Motivation</h3>
<p>In optimization, the connection between different types of local optima and generalization has been explored extensively in many studies [2]. These studies show that sharp minima often lead to larger test errors while flatter minima yield better generalization. This finding raised a new research direction in deep learning that seeks out flatter minima when training neural networks.</p>
<p>The two most popular flatness-aware solvers are Sharpness-Aware Minimization (SAM) and Stochastic Weight Averaging (SWA). SAM is a procedure that simultaneously minimizes loss value and loss sharpness, this procedure finds flat minima directly but also doubles training cost. Meanwhile, SWA finds flat minima by a weight ensemble approach and has almost no computational overhead.</p>
<h3 id="method-1">Method</h3>
<p>Intuitively, SWA updates a pre-trained model (namely, a model trained for a sufficient number of epochs, $K_0$) with a cyclical or high constant learning rate schedule. SWA gathers the model parameters every $K$ epochs during this update and averages them to form the model ensemble. In other words, SWA finds an ensembled solution over the different local optima reached with a learning rate large enough to escape each local minimum.</p>
<p>SWA has been included in <a href="https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/">PyTorch</a> since 2020. We need two ingredients to apply SWA to our model: a <code class="language-plaintext highlighter-rouge">swa_model</code> and a <code class="language-plaintext highlighter-rouge">swa_scheduler</code>. Snippet 3 illustrates how to initialize these two entities in PyTorch. Figure 2 shows the whole learning rate schedule during training, where $K_0$ is set to <code class="language-plaintext highlighter-rouge">T_max</code> of the base scheduler.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 3: Initializing swa_model and swa_scheduler.
"""</span>
<span class="kn">import</span> <span class="nn">torch.optim</span> <span class="k">as</span> <span class="n">optim</span>
<span class="p">...</span>
<span class="n">swa_model</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">swa_utils</span><span class="p">.</span><span class="n">AveragedModel</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
<span class="n">swa_scheduler</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">swa_utils</span><span class="p">.</span><span class="n">SWALR</span><span class="p">(</span>
<span class="n">optimizer</span><span class="p">,</span> <span class="n">swa_lr</span> <span class="o">=</span> <span class="mf">1e-2</span><span class="p">,</span>
<span class="n">anneal_strategy</span> <span class="o">=</span> <span class="s">"cos"</span><span class="p">,</span> <span class="n">anneal_epochs</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
<span class="p">)</span>
<span class="p">...</span>
</code></pre></div></div>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/lr-schedule.jpg" />
<figcaption>Figure 2. Learning rate schedule during training. </figcaption>
</figure>
<p>Snippet 4 below briefs the training loop. Applying SWA to ML models that have BatchNorm layers is a little tricky: we need to use the utility function <code class="language-plaintext highlighter-rouge">update_bn</code> to recompute the BatchNorm statistics of the SWA model on a given data loader.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Snippet 4: Training loop.
"""</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
<span class="p">...</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">epoch</span> <span class="o">></span> <span class="n">scheduler</span><span class="p">.</span><span class="n">T_max</span><span class="p">:</span>
<span class="n">scheduler</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">swa_model</span><span class="p">.</span><span class="n">update_parameters</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">())</span>
<span class="n">swa_scheduler</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
<span class="p">...</span>
<span class="p">...</span>
<span class="n">optim</span><span class="p">.</span><span class="n">swa_utils</span><span class="p">.</span><span class="n">update_bn</span><span class="p">(</span><span class="n">loaders</span><span class="p">[</span><span class="s">"train"</span><span class="p">],</span> <span class="n">swa_model</span><span class="p">)</span>
<span class="p">...</span>
</code></pre></div></div>
<h2 id="4-results">4. Results</h2>
<p>The table below shows the performance of the two methods presented in this article (AgeReg and SWA), together with the baseline:</p>
<table>
<thead>
<tr>
<th style="text-align: left"> </th>
<th style="text-align: right">Chapman</th>
<th style="text-align: right">CPSC</th>
<th style="text-align: right">CPSC-Extra</th>
<th style="text-align: right">G12EC</th>
<th style="text-align: right">Ningbo</th>
<th style="text-align: right">PTB-XL</th>
<th style="text-align: right">Avg</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align: left">Baseline </td>
<td style="text-align: right"> 0.4290</td>
<td style="text-align: right"> 0.1643</td>
<td style="text-align: right"> 0.2067</td>
<td style="text-align: right"> 0.3809</td>
<td style="text-align: right"> 0.3987</td>
<td style="text-align: right"> 0.3626</td>
<td style="text-align: right"> 0.3237</td>
</tr>
<tr>
<td style="text-align: left">AgeReg</td>
<td style="text-align: right"> 0.4222</td>
<td style="text-align: right"> 0.1715</td>
<td style="text-align: right"> 0.2136</td>
<td style="text-align: right"> 0.3923</td>
<td style="text-align: right"> 0.4024</td>
<td style="text-align: right"> 0.4021</td>
<td style="text-align: right"> <strong>0.3340</strong></td>
</tr>
<tr>
<td style="text-align: left">SWA</td>
<td style="text-align: right"> 0.4271</td>
<td style="text-align: right"> 0.1759</td>
<td style="text-align: right"> 0.2052</td>
<td style="text-align: right"> 0.3969</td>
<td style="text-align: right"> 0.4313</td>
<td style="text-align: right"> 0.4203</td>
<td style="text-align: right"> <strong>0.3428</strong></td>
</tr>
</tbody>
</table>
<p>To be continued …</p>
<h2 id="references">References</h2>
<p style="font-size: 14px;"><a href="https://arxiv.org/abs/2009.09796">[1] Multi-task Learning with Deep Neural Networks: A Survey</a><br />
<a href="https://arxiv.org/abs/1609.04836">[2] On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima</a><br />
<a href="https://arxiv.org/abs/2010.01412">[3] Sharpness-Aware Minimization for Efficiently Improving Generalization</a><br />
<a href="https://arxiv.org/abs/1803.05407">[4] Averaging Weights Leads to Wider Optima and Better Generalization</a><br />
<a href="https://arxiv.org/abs/2102.08604">[5] SWAD: Domain Generalization by Seeking Flat Minima</a><br /></p>Gathering.AI👋 Hi there. Welcome back to my page, this is part 2 of my tutorial series about the topic of Domain Generalization (DG). In this article, I will introduce the first approach to the DG problem, which I call conventional generalization.Domain Generalization Tutorials - Part 12022-09-17T00:00:00+00:002022-09-17T00:00:00+00:00https://gather-ai.github.io/tutorials/domain-generalization-part-1<p>👋 Hi there. I’m Khiem. Welcome to my page, where I gather and share some intuitive explanations and hands-on tutorials on a range of topics in AI.</p>
<p>🚀 I am going to kick off this website with a series of tutorials about the topic of Domain Generalization. This series provides a systematic survey of outstanding methods in literature and my own implementations to demonstrate these methods. This is the first part of the series that gives you a brief understanding of the term Domain Generalization. Let’s get started.</p>
<!-- You can find the source code of the whole series [here](https://github.com/lhkhiem28/DGECG).
{: .notice--info} -->
<h2 id="1-background">1. Background</h2>
<h3 id="motivation">Motivation</h3>
<p>Machine Learning (ML) systems generally rely on an over-simplified assumption: that the training (source) and testing (target) data are independent and identically distributed (i.i.d.). However, this assumption is not always true in practice. When the distributions of training data and testing data differ, which is referred to as the domain shift problem, the performance of these ML systems often decreases catastrophically due to domain distribution gaps. Moreover, in many applications, target data is difficult to obtain or even unknown before deploying the model. For example, in biomedical applications, where data differs from equipment to equipment and institute to institute, it is impractical to collect the data of all possible domains in advance.</p>
<p>To address the domain shift problem, as well as the absence of target data, the topic of Domain Generalization (DG) was introduced. Specifically, the goal in DG is to learn a model using data from a single or multiple related but distinct source domains in such a way that the model can generalize well to any <strong><em>unseen</em></strong> target domain.</p>
<p class="notice"><strong>Watch out!</strong> Unlike other related topics such as Domain Adaptation (DA) or Transfer Learning (TL), where the ML models can do some forms of adaptation on target data, DG considers more ubiquitous scenarios in practice where target data is inaccessible during model learning.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/DG-DA.jpg" />
<figcaption>Figure 1. Examples from the PACS dataset for DG. Adapted from [1]. </figcaption>
</figure>
<h3 id="formulation">Formulation</h3>
<p>In the context of DG, we have access to $M$ similar but distinct source domains \(S_{source}=\{S_m=\{(x, y)\}\}_{m=1}^M\), each associated with a joint distribution \(P_{XY}^{(m)}\) with:</p>
<ul>
<li>\(P_{XY}^{(m)}\neq P_{XY}^{({m}')}\) with \(m\neq {m}'\) and \(m, {m}'\in \{1, ..., M\}\)</li>
<li>\(P_{Y\mid X}^{(m)}= P_{Y\mid X}^{({m}')}\) with \(m\neq {m}'\) and \(m, {m}'\in \{1, ..., M\}\)</li>
</ul>
<p>and we have to minimize prediction error on an unseen target domain \(S_{target}\) with:</p>
<ul>
<li>\(P_{XY}^{(target)}\neq P_{XY}^{(m)}\) with \(m\in \{1, ..., M\}\)</li>
<li>\(P_{Y\mid X}^{(target)}= P_{Y\mid X}^{(m)}\) with \(m\in \{1, ..., M\}\)</li>
</ul>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/DG-formulation.jpg" />
<figcaption>Figure 2. Illustration of DG. Adapted from [1]. </figcaption>
</figure>
<h2 id="2-tutorial-settings">2. Tutorial Settings</h2>
<p>In this series of tutorials, besides introducing and explaining outstanding DG methods intuitively, I also provide implementations and practice them on a real-world problem: classifying cardiac abnormalities from twelve-lead <a href="https://en.wikipedia.org/wiki/Electrocardiography">ECGs</a> (see more details in the <a href="https://moody-challenge.physionet.org/2021/">PhysioNet Challenge 2021</a>). This can help you understand the methods better and apply them to your own problems immediately.</p>
<h3 id="datasets">Datasets</h3>
<p>The datasets are from PhysioNet Challenge 2021, containing twelve-lead ECG recordings from 6 institutes in 4 countries across 3 continents. Each recording was annotated with one or more of 26 types of cardiac abnormalities, which means the problem is <em>multi-label classification</em>. Figure 3 shows the number of data samples in each dataset, and Figure 4 illustrates the difference in the appearance of signals from 6 institutes.</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/data-sources.jpg" />
<figcaption>Figure 3. The number of data samples in each dataset. </figcaption>
</figure>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/signal-appearance.jpg" />
<figcaption>Figure 4. The difference in the appearance of signals from 6 institutes. </figcaption>
</figure>
<p>I recommend reading some documents to understand what ECGs and cardiac abnormalities are, as well as our problem:</p>
<ul>
<li><a href="https://en.wikipedia.org/wiki/Electrocardiography">https://en.wikipedia.org/wiki/Electrocardiography</a></li>
<li><a href="https://www.who.int/health-topics/cardiovascular-diseases">https://www.who.int/health-topics/cardiovascular-diseases</a></li>
<li><a href="https://arxiv.org/abs/2207.12381">https://arxiv.org/abs/2207.12381</a></li>
</ul>
<p>Important things to remember about our problem (a quick shape check in code follows the list):</p>
<ul>
<li>Input: twelve 1D-signals, a matrix with a shape of (12, 5000)</li>
<li>Output: one or more of 26 classes, a vector of 26 elements, each 0 or 1</li>
</ul>
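<p>To make these shapes concrete, here is a quick sanity check (the batch size of 8 is arbitrary):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

ecgs = torch.randn(8, 12, 5000)                # a batch of 8 twelve-lead recordings
labels = torch.randint(0, 2, (8, 26)).float()  # one or more of 26 classes per recording
assert ecgs.shape[1:] == (12, 5000)
assert labels.shape[1:] == (26,)
</code></pre></div></div>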
<h3 id="baseline">Baseline</h3>
<p>We always need a baseline model before applying any advanced methods. Here, I use the following setup (sketched in code after the list):</p>
<ul>
<li>One-dimensional <a href="https://arxiv.org/abs/1709.01507">SEResNet34</a> model</li>
<li>Binary cross-entropy loss function (the problem is multi-label)</li>
<li>Adam optimizer with <code class="language-plaintext highlighter-rouge">lr</code> = 1e-3 and <code class="language-plaintext highlighter-rouge">weight_decay</code> = 5e-5</li>
<li>Cosine Annealing scheduler with <code class="language-plaintext highlighter-rouge">eta_min</code> = 1e-4 and <code class="language-plaintext highlighter-rouge">T_max</code> = 50</li>
<li>The batch size is 512 and the number of epochs is 80</li>
</ul>
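<p>These settings translate directly into PyTorch; a minimal sketch, assuming <code class="language-plaintext highlighter-rouge">model</code> is the one-dimensional SEResNet34:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-4)
</code></pre></div></div>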
<p>Evaluation of DG algorithms often follows the <em>leave-one-domain-out</em> rule: one dataset is left out as the target domain while the others are treated as the training part, i.e., the source domains. Based on this evaluation strategy, the baseline model’s performance is shown in the table below, followed by a short sketch of the evaluation loop:</p>
<figure class="align-center">
<img src="https://gather-ai.github.io/assets/images/domain-generalization/leave-one-domain-out.jpg" />
<figcaption>Figure 5. Leave-one-domain-out evaluation strategy. </figcaption>
</figure>
<table>
<thead>
<tr>
<th style="text-align: left"> </th>
<th style="text-align: right">Chapman</th>
<th style="text-align: right">CPSC</th>
<th style="text-align: right">CPSC-Extra</th>
<th style="text-align: right">G12EC</th>
<th style="text-align: right">Ningbo</th>
<th style="text-align: right">PTB-XL</th>
<th style="text-align: right">Avg</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align: left">Baseline </td>
<td style="text-align: right"> 0.4290</td>
<td style="text-align: right"> 0.1643</td>
<td style="text-align: right"> 0.2067</td>
<td style="text-align: right"> 0.3809</td>
<td style="text-align: right"> 0.3987</td>
<td style="text-align: right"> 0.3626</td>
<td style="text-align: right"> <strong>0.3237</strong></td>
</tr>
</tbody>
</table>
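<p>For concreteness, the leave-one-domain-out loop can be sketched as follows; <code class="language-plaintext highlighter-rouge">datasets</code>, <code class="language-plaintext highlighter-rouge">train</code>, and <code class="language-plaintext highlighter-rouge">evaluate</code> are assumed helpers, not code from this series:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>datasets = {"Chapman": ..., "CPSC": ..., "CPSC-Extra": ...,
            "G12EC": ..., "Ningbo": ..., "PTB-XL": ...}

scores = {}
for target in datasets:
    # Train on all domains except the held-out target domain
    sources = [data for name, data in datasets.items() if name != target]
    model = train(sources)
    scores[target] = evaluate(model, datasets[target])

print(sum(scores.values()) / len(scores))  # the Avg column in the tables
</code></pre></div></div>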
<p>Now we are ready to start exploring DG methods, next part of the series will present the first approach.</p>
<h2 id="references">References</h2>
<p style="font-size: 14px;"><a href="https://arxiv.org/abs/2103.03097">[1] Generalizing to Unseen Domains: A Survey on Domain Generalization</a><br />
<a href="https://arxiv.org/abs/2103.02503">[2] Domain Generalization: A Survey</a><br /></p>Gathering.AI👋 Hi there. I’m Khiem. Welcome to my page, where I gather and share some intuitive explanations and hands-on tutorials on a range of topics in AI.