1 Introduction

The research on deep learning has brought remarkable progress to many interdisciplinary fields, such as medical image analysis Litjens et al. (2017); Lundervold and Lundervold, (2019); Zhou et al. (2021). Many purpose-designed neural network architectures and corresponding end-to-end solutions have been proposed to assist diagnosis and treatment planning in real-world clinical applications. Their goal is to provide better and more intelligent healthcare services by improving doctors' diagnostic efficiency and reducing patients' treatment costs. These techniques have become prevalent, especially in two-dimensional (2D) and three-dimensional (3D) volumetric medical image analysis. For example, for lung disease diagnosis, Gordienko et al. (2018) performed lung segmentation in X-ray scans with a UNet-based Ronneberger et al. (2015) structure. A multi-scale convolutional neural network (MCNN) was proposed by Shen et al. (2015) to detect lung nodules in computed tomography (CT) slices of 3D volumetric images. Further applications have also been investigated, such as liver segmentation by Christ et al. (2017) and Long et al. (2015), breast cancer diagnosis by Bejnordi et al. (2017), and brain magnetic resonance imaging (MRI) analysis. Beyond regular 2D or 3D grid images, geometric medical data, such as 3D meshes, point clouds, and molecules, have become increasingly common in medical applications, while satisfactory algorithms for analyzing such data are still under development.

In dental research, intra-oral scanners (IOSs) have become popular and are extensively used in digital orthodontics. They generate a 3D mesh of the dental anatomy that is more accurate than a plaster model. In clinical diagnosis, an important step is to precisely segment each individual tooth and the gingiva in the IOS meshes acquired from the scanners, i.e., 3D tooth segmentation. Specifically, given an IOS mesh consisting of more than 150,000 triangular faces, 3D tooth segmentation classifies each face as a specific tooth or as gingiva following the FDI Herrmann, (1967) notation. The segmentation outputs are an indispensable prerequisite for subsequent steps such as diagnosis and treatment planning in orthodontics and implantology. Several pioneering studies by Xu et al. (2019), Tian et al. (2019), Zanjani et al. (2019), Cui et al. (2021), Lian et al. (2020), and Hao et al. (2021) have addressed the 3D tooth segmentation problem with geometry-based or deep-learning-based methods, such as models based on conventional 2D or 3D convolutional neural networks (CNNs) or dedicated tooth segmentation networks. These methods perform well on their own test sets, but several limitations remain, as elaborated below.

Accurate and automatic 3D tooth segmentation remains challenging for the following reasons. On the one hand, IOS samples vary significantly among patients, for example, in tooth shape (teeth with cavities, defective restorations, or attrition), tooth number (missing teeth, hypodontia, or hyperdontia), tooth size (microdontia and macrodontia), and position (tipped, rotated, or shifted teeth and crowding), leading to large heterogeneity among patient samples Hao et al. (2021). Such heterogeneity poses serious challenges to developing robust and accurate 3D tooth segmentation solutions. On the other hand, most of these methods are evaluated only on their own test sets because no large-scale, multi-centric IOS dataset is publicly available. For example, the method of Cui et al. (2021) is evaluated on fewer than 50 samples collected at a single center from patients without third molars. Previous research has also shown that the performance of such methods degrades substantially when they are evaluated on a large-scale dataset with more complicated cases Hao et al. (2021). Given these two constraints, it would be valuable to aggregate data from multiple hospitals and clinics, i.e., to collect many more samples with high heterogeneity from different sites and patients. However, sharing and exchanging data among hospitals and clinics is often impractical because of privacy and regulatory concerns. Overcoming such data silos is essential for clinically applicable 3D tooth segmentation in multi-centric scenarios.

Recently, the federated learning (FL) framework was proposed for collaborative, distributed learning across multiple participants (such as hospitals) without explicit data sharing McMahan et al. (2017). Participants, also termed clients, whether large hospitals or small clinics, use their own computational resources to train on their local datasets and share their model parameters with a central server. In this collaborative mode, clients contribute to the same global model and can significantly boost its performance without exchanging local data. FL has been applied to many fields, including mobile-edge computing Lu et al. (2020) and the Internet of Things (IoT) Ren et al. (2019). There is also pioneering work on federated learning in medical image analysis Kaissis et al. (2020); Rieke et al. (2020); Fan et al. (2021); Warnat-Herresthal et al. (2021). For example, Liu Q. et al. (2020) improved prostate segmentation by learning shared knowledge from heterogeneous datasets at multiple sites. However, to the best of our knowledge, no previous work has explored the feasibility of FL for 3D tooth segmentation, owing to the challenges of 3D geometric medical data analysis, although the demand is pressing given the rapidly increasing number of dental patients.

In this study, we propose FedTSeg, a framework built on general FL for distributed 3D tooth semantic segmentation with a privacy-preserving module, and evaluate it under various settings. We first formulate 3D tooth segmentation as a point cloud segmentation task and design the corresponding segmentation architecture based on EdgeConv blocks Wang et al. (2019). Under the general FedAvg setting McMahan et al. (2017), which scales easily to large datasets with competitive performance, we investigate the tooth segmentation performance of FedTSeg with balanced and imbalanced distributions of data samples among clients. We also study the effect of the number of clients with heterogeneous IOS samples. Furthermore, to prevent potential parameter leakage during the federated process, we adopt a homomorphic encryption module in FedTSeg to protect the communication between clients and the server. Comprehensive experiments with 500 IOS samples demonstrate that FedTSeg achieves a mean intersection over union (mIoU), Dice similarity coefficient (DSC), and accuracy (ACC) of 81.49%, 86.42%, and 92.53%, respectively, significantly outperforming conventional counterparts trained locally. Moreover, while preserving privacy during distributed learning, FedTSeg performs on par with a centralized model trained on the aggregated data from all clients. Our work is the first attempt at federated learning for 3D tooth segmentation over geometric medical data and demonstrates the strong potential of federated learning for challenging 3D medical image analysis tasks in distributed multi-centric settings.

Our main contributions can be summarized as follows:

• We established the federated tooth segmentation (FedTSeg) framework, based on a deep graph convolutional neural network, for privacy-preserving distributed 3D tooth segmentation, and investigated IOS tooth segmentation performance under various settings, such as varying numbers of clients and different distributions of heterogeneous IOS mesh scans.

• We achieved privacy-preserving federated 3D tooth segmentation with a homomorphic encryption mechanism to prevent potential parameter leakage during communication.

• We demonstrated the effectiveness of FedTSeg through comprehensive experiments, which show that FedTSeg attains a better global model than conventional local training while protecting patients' privacy and securing the communication between clients and the central server.

The rest of this article is organized as follows. Related work is reviewed in Section 2, the system model is introduced in Section 3, the methods are elaborated in Section 4, the experiments are described in Section 5, discussion and analysis are given in Section 6, and conclusions are drawn in Section 7.

2 Related Work

2.1 3D Tooth Segmentation

We formulate 3D tooth segmentation as a 3D point cloud segmentation task, a specific branch of 3D shape segmentation on which there has been substantial work. PointNet Qi et al. (2017a) and PointNet++ Qi et al. (2017b) are widely used for point clouds. They support semantic segmentation of objects and scenes but do not fully capture the geometric relationships between points, leading to inferior performance in complex scenes. Wang et al. (2019) proposed the dynamic graph convolutional neural network (DGCNN) based on the EdgeConv block, which captures both local and global representations. However, these methods do not perform as well as expected on 3D tooth segmentation, because IOS tooth meshes, with their higher resolution and complicated anatomical structures, differ significantly from natural objects. Recently, several works have improved 3D tooth segmentation based on point cloud segmentation. Some methods first extract predefined geometric features from the original mesh and then solve the 3D tooth segmentation task with 2D/3D convolutional neural networks (CNNs) Xu et al. (2019); Tian et al. (2019). Dedicated neural networks have also been proposed to improve tooth segmentation, such as DC-Net Hao et al. (2021), TSegNet Cui et al. (2021), and MeshSegNet Lian et al. (2020). They achieve precise segmentation on regular upper- or lower-jaw scans but perform poorly when restricted to limited training data or tested on heterogeneous IOS samples, and no large-scale annotated dataset is publicly available.
In our work, we employ the FL framework with a segmentation backbone composed of EdgeConv blocks and CNNs. It takes advantage of distributed IOS samples from multiple hospitals and clinics through local training and global aggregation to obtain a well-trained global model while addressing privacy concerns.

2.2 Federated Learning

Federated learning was first proposed to conduct distributed training on decentralized data across many client devices. In contrast to conventional centralized learning, FL does not require explicit data sharing among clients or institutions. A typical algorithm is federated averaging (FedAvg) McMahan et al. (2017), which provides a general platform consisting of a central server and a number of distributed clients. More advanced FL frameworks have been proposed to tackle different issues, such as robustness, privacy, and heterogeneity [FedBN Li et al. (2021), FedProx Li et al. (2020), MFL Liu W. et al. (2020)]. However, several challenges still hinder the practical application of federated learning, such as performance decline under non-IID and heterogeneous data distributions, low communication efficiency under heavy traffic, and potential privacy leakage. In clinical settings, recent works have adopted federated learning for specific medical tasks, such as brain-tumor segmentation Sheller et al. (2020), COVID-19 screening Feki et al. (2021), and prostate cancer classification Yan et al. (2021). In addition, Kaissis et al. (2020) presented an overview of cutting-edge secure methods for federated learning in medical imaging, Fan et al. (2021) provided an FL framework for 3D brain MRI images, and Warnat-Herresthal et al. (2021) developed a decentralized edge-computing framework for medical imaging with a permissioned blockchain. In this work, we focus on the challenging 3D tooth segmentation task, which has not been investigated in the FL setting. Previous work by Yeom et al. (2018), Melis et al. (2019), and Song et al. (2020) has shown that an untrustworthy server could infer features of the training data via reverse engineering during model updates.
Such attackers can infer label information and features of local datasets from the gradient information uploaded by clients. Thus, model parameters need secure protection during communication between clients and the server. To protect patients' privacy, we further include a homomorphic encryption mechanism to secure the parameter exchange between clients and the server.

2.3 Federated Learning for Medical Image Segmentation

Previous work has also aimed to improve segmentation performance with federated learning systems. Li et al. (2019) proposed an FL framework with differential privacy for brain-tumor segmentation to preserve patient data privacy. Bercea et al. (2021) proposed a framework for federated unsupervised brain anomaly segmentation. Lo et al. (2021) showed that a federated learning model could achieve results similar to those of models trained on fully centralized data for microvasculature segmentation. While these methods mainly focus on regular 2D segmentation tasks on grid images, this work investigates the more challenging 3D tooth segmentation task on complicated and heterogeneous geometric medical data. To the best of our knowledge, there is no prior work on federated learning for 3D medical segmentation over geometric data, and our work is the first step toward FL for tooth segmentation on large-scale heterogeneous 3D IOS meshes.

3 Problem Formulation and System Model

We aim to solve the 3D tooth segmentation task by simulating cooperation among a large group of medical institutions within a federated learning framework. We first define the 3D tooth segmentation task on IOS meshes concretely. Let $m = (V, F)$ denote an IOS mesh, where $V$ is the set of vertices and $F$ the set of triangular faces of the mesh. The goal of 3D IOS tooth segmentation is to assign a label $y_t$ to each triangular face $f_t$, where

$$y_t \in \{0\} \cup \{11, \ldots, 18\} \cup \{21, \ldots, 28\} \cup \{31, \ldots, 38\} \cup \{41, \ldots, 48\}$$

denotes the gingiva (0) and the FDI Herrmann, (1967) notations for the 32 permanent teeth, respectively.
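Because a network outputs class indices rather than FDI codes, the 33 labels above must be mapped to a contiguous range 0 to 32. The Python sketch below assumes one possible ordering (gingiva first, then quadrants 1 to 4 with teeth 1 to 8 in each); the ordering actually used by FedTSeg is not stated, and `FDI_CODES`, `fdi_to_class`, and `class_to_fdi` are hypothetical names.

```python
# Hypothetical label ordering: gingiva (0) first, then FDI codes 11-18, 21-28, 31-38, 41-48.
FDI_CODES = [0] + [10 * quadrant + tooth for quadrant in range(1, 5) for tooth in range(1, 9)]
CODE_TO_CLASS = {code: index for index, code in enumerate(FDI_CODES)}


def fdi_to_class(code: int) -> int:
    """Map an FDI code (or 0 for gingiva) to a contiguous class index in [0, 32]."""
    if code not in CODE_TO_CLASS:
        raise ValueError(f"{code} is neither gingiva (0) nor a permanent-tooth FDI code")
    return CODE_TO_CLASS[code]


def class_to_fdi(index: int) -> int:
    """Inverse mapping, used to report predictions in FDI notation."""
    return FDI_CODES[index]


print(len(FDI_CODES))      # 33
print(fdi_to_class(21))    # 9
print(class_to_fdi(32))    # 48
```

Any fixed bijection works equally well; what matters is that every client uses the same mapping so that the aggregated output layer stays consistent.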

Our system model is illustrated in Figure 1. The system contains three parts: n distributed clients, a global server, and an independent encryption authority. The clients represent medical institutions participating in federated learning, such as clinics or hospitals, and the IOS scans of their patients are securely stored in local databases. Large institutions, such as public hospitals, typically have larger and more diverse datasets than small institutions such as clinics, which we also simulate in our experiments. Within this framework, clients use their data and computational capability to perform local training and participate in federated learning by sending their local model parameters $\omega_i$ to the global server. The global server aggregates the parameters from the distributed clients and sends the updated parameters $\omega$ back to the clients.


FIGURE 1. A system model of our FL architecture. $\omega_1, \omega_2, \ldots, \omega_n$ represent the encrypted local model parameters of each client, and $\omega_{t+1}$ represents the updated global model parameters after aggregation.

The federated learning process is further secured by an additional homomorphic encryption authority, a third party independent of the server. Before sending local parameters, clients request public and private keys from the encryption authority and use the public key to encrypt their local model parameters. In each communication round, the server aggregates the encrypted parameters from all clients and distributes the result back to each client. This prevents information leakage during federated communication, for example, leakage of gradients that could be used to reconstruct the original training data. The federated learning algorithm and the homomorphic encryption mechanism are detailed in the following sections.

4 Methods

In this section, we systematically introduce the proposed FedTSeg framework. We first introduce the federated learning framework, which consists of the clients and server modules. Next, we present the details of our segmentation backbone for 3D IOS tooth segmentation, which is built with the EdgeConv blocks and inspired by the DGCNN. Finally, we mathematically define the homomorphic encryption process that can help secure the federated learning process.

4.1 Federated Learning Framework

The FedTSeg framework is mainly based on the general FedAvg architecture, which can easily scale to large-scale datasets. Below is a detailed description of the federated learning framework with multiple distributed clients, e.g., hospitals and clinics that can perform model training locally, and a global server that aggregates models from the clients.

4.1.1 Client Module

Assume there are n clients participating in federated tooth segmentation, i.e., $\mathcal{C} = \{c_1, c_2, \ldots, c_n\}$, where $\mathcal{C}$ denotes the set of all clients and $c_k$ denotes the k-th client. Each client $c_k$ is associated with a local dataset $\mathcal{D}_k = \{d_i^k\}_{i=1}^{N_k}$, where $N_k = |\mathcal{D}_k|$ denotes the size of the local dataset. Each data sample is defined as $d_i^k = (x_i, y_i)$, where $x_i$ represents the IOS scan of the i-th patient and $y_i$ denotes the corresponding annotated labels for each mesh face.

We now define the feature extraction and learning process in each client. As mentioned above, we transform the original segmentation task over 3D IOS surfaces into a segmentation task over 3D point clouds. Given an IOS mesh $m = (V, F)$, where $V$ are the vertices and $F$ are the faces of the mesh, we first randomly sample $N$ face centers from all triangular faces and extract a 15-dimensional feature for each, leading to a point cloud $P = \{p_i \mid i = 1, \ldots, N\}$. In particular, each triangular face $f_t$ has three vertices $v = \{(x_i, y_i, z_i) \mid i = 1, 2, 3\}$, where $x$, $y$, and $z$ are the 3D Cartesian coordinates of a vertex. The center of the triangle is $v_c = \frac{1}{3}\sum_{i=1}^{3} (x_i, y_i, z_i)$. We further extract geometric features from the original mesh for each point. Specifically, we compute the normal vector $h_n \in \mathbb{R}^3$ of the triangular face and a structure descriptor $h_s \in \mathbb{R}^9$. With the three vertices $v_i$ and the triangle center $v_c$, we have $h_s = \mathrm{concat}(\{v_i - v_c\}_{i=1}^{3})$, where $\mathrm{concat}(\cdot)$ denotes vector concatenation. Hence, each point $p_i$ is associated with a feature vector $h = [v_c, h_s, h_n] \in \mathbb{R}^{15}$. In addition, each point $p_i$ is associated with a label $y_i$ that assigns the face to a tooth code, where $y_i \in \{0\} \cup \{11, \ldots, 18\} \cup \{21, \ldots, 28\} \cup \{31, \ldots, 38\} \cup \{41, \ldots, 48\}$ denotes the gingiva (0) and the FDI (Fédération Dentaire Internationale) notations for the 32 permanent teeth. Every client performs feature extraction in the same way.
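The per-point feature construction can be sketched with NumPy as follows. This is a minimal illustration, not the authors' code: the function name `extract_face_features`, sampling without replacement whenever the mesh has enough faces, and the unit-normalized face normal are our assumptions.

```python
import numpy as np


def extract_face_features(vertices, faces, num_points, rng):
    """Sample `num_points` faces and build the 15-D feature [v_c, h_s, h_n] for each.

    vertices: (V, 3) float array of vertex coordinates.
    faces:    (F, 3) int array of vertex indices per triangle.
    Returns the (num_points, 15) feature array and the sampled face indices.
    """
    num_faces = len(faces)
    # Sample without replacement when the mesh has enough faces (assumption).
    chosen = rng.choice(num_faces, size=num_points, replace=num_faces < num_points)
    tri = vertices[faces[chosen]]                     # (N, 3, 3): three vertices per face
    center = tri.mean(axis=1)                         # v_c = (v_1 + v_2 + v_3) / 3
    h_s = (tri - center[:, None, :]).reshape(-1, 9)   # concat(v_i - v_c), 9-D
    normal = np.cross(tri[:, 1] - tri[:, 0], tri[:, 2] - tri[:, 0])
    normal /= np.maximum(np.linalg.norm(normal, axis=1, keepdims=True), 1e-12)  # unit h_n
    return np.concatenate([center, h_s, normal], axis=1), chosen
```

The returned indices let a client pick the matching per-face labels, e.g. `point_labels = face_labels[chosen]`, so that each point keeps the label of the face it was sampled from.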

Before communicating with the central server, each client conducts local training for E epochs with the segmentation network described in Section 4.2. The goal of each client is to find local model parameters $\omega$ that minimize the loss function $\ell_i(\omega; x_i, y_i)$, defined in Section 4.2, which quantifies the discrepancy between the predicted labels $\hat{y}_i$ (determined by $\omega$) and the true labels $y_i$. Considering all samples in the local dataset, the objective function for client k is

$$L_k(\omega) = \frac{1}{N_k} \sum_{i \in \mathcal{D}_k} \ell_i(\omega; x_i, y_i). \qquad (1)$$

During local training, each client updates its local parameters with stochastic gradient descent (SGD), $\omega \leftarrow \omega - \eta \nabla L_k(\omega)$, where $\eta$ is the learning rate and $\nabla L_k(\omega)$ is the average gradient over the local training data. A more advanced optimizer, such as Adam by Kingma and Ba (2015), can be used for better convergence and performance. When the t-th round of local training is finished, the client requests a public key and a private key from the encryption authority, encrypts its local model parameters $\omega_t^k$ with the public key, and uploads the ciphertext (the locally encrypted parameters) to the server.

Upon receiving the updated encrypted message from the server, the client decrypts it with the private key to obtain the updated model parameters $\omega_{t+1}^k$ and then begins the next round of local training. Owing to the encryption/decryption mechanism, the model parameters are protected throughout communication.
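To make Eq. 1 and the update rule concrete, the following sketch runs one client's local mini-batch SGD. It substitutes a toy squared-error loss on a linear model for the segmentation network, so `local_loss`, `local_grad`, and `local_train` are hypothetical helpers that illustrate only the optimization loop, not FedTSeg's training code.

```python
import numpy as np


def local_loss(w, X, y):
    """Eq. (1) with a toy squared-error loss ell_i, averaged over the client's samples."""
    return float(np.mean((X @ w - y) ** 2))


def local_grad(w, X, y):
    """Average gradient of the toy loss over the given (mini-)batch."""
    return 2.0 * X.T @ (X @ w - y) / len(y)


def local_train(w, X, y, epochs, lr, batch_size, rng):
    """Run `epochs` passes of mini-batch SGD, w <- w - lr * grad L_k(w), on one client."""
    w = w.copy()                                   # keep the received global model intact
    for _ in range(epochs):
        order = rng.permutation(len(y))            # reshuffle the local data every epoch
        for start in range(0, len(y), batch_size):
            batch = order[start:start + batch_size]
            w -= lr * local_grad(w, X[batch], y[batch])
    return w
```

Swapping the plain update for Adam, as the text suggests, changes only the inner update line; the surrounding epoch loop stays the same.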

4.1.2 Server Module

The global server aggregates the model parameters from all clients and distributes the updated global model back to them. In the vanilla FedAvg framework, the server collects the local model parameters from each client, computes their weighted average, and sends the updated parameters back to each client. We define this aggregation as

$$\omega_{t+1} \leftarrow \sum_{k=1}^{n} \frac{N_k}{N} \omega_t^k, \quad \text{where } N = \sum_{k=1}^{n} N_k. \qquad (2)$$

In FedTSeg, we modify Eq. 2 to incorporate the encryption mechanism and ensure secure parameter exchange. Let $\zeta_t^k$ denote the encrypted local model parameters of client k. The server computes the updated parameters $\zeta_{t+1}$ as

$$\zeta_{t+1} \leftarrow \sum_{k=1}^{n} N_k \zeta_t^k, \qquad (3)$$

which omits the division by $N$. The reason is that additive homomorphic encryption does not support multiplying a ciphertext by a floating-point number such as $N_k / N$. It does support adding two ciphertexts, which implies that it also supports multiplying a ciphertext by a non-negative integer such as $N_k$ (repeated addition). Hence, in FedTSeg, the server sends $N$ back to each client, and each client performs the division locally.
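The equivalence between Eq. 2 and the server-side integer-weighted sum of Eq. 3 followed by a client-side division by $N$ can be checked numerically. The sketch below works on plaintext NumPy arrays (encryption is left out), and the function names are illustrative rather than taken from FedTSeg.

```python
import numpy as np


def fedavg_aggregate(client_params, client_sizes):
    """Eq. (2): plaintext FedAvg, weighting each client by N_k / N."""
    total = sum(client_sizes)
    return sum((n_k / total) * w_k for n_k, w_k in zip(client_sizes, client_params))


def server_integer_aggregate(client_params, client_sizes):
    """Eq. (3): integer-weighted sum only (no division), as the server does on ciphertexts."""
    summed = sum(n_k * w_k for n_k, w_k in zip(client_sizes, client_params))
    return summed, sum(client_sizes)


def client_finalize(summed, total):
    """Each client divides by N locally after decryption."""
    return summed / total


params = [np.array([0.5, -1.0]), np.array([0.1, 0.2]), np.array([-0.3, 0.9])]
sizes = [120, 300, 80]
summed, N = server_integer_aggregate(params, sizes)
print(np.allclose(client_finalize(summed, N), fedavg_aggregate(params, sizes)))  # True
```

Because division by $N$ is linear, moving it from the server to the clients leaves the aggregated model unchanged; only the integer weights $N_k$ ever have to touch the ciphertexts.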

4.2 3D Tooth Segmentation Network Architecture

In this work, we design the 3D tooth segmentation neural network based on the EdgeConv blocks as inspired by the DGCNN Wang et al. (2019), which is capable of processing point clouds and can be trained and evaluated in an end-to-end manner. The network architecture is illustrated in Figure 2.


FIGURE 2. The architecture of our 3D tooth segmentation network: Input mesh → Point cloud $x \in \mathbb{R}^{n \times f}$ → Transform net → EdgeConv → EdgeConv → EdgeConv → Conv2D[1024] → Maxpool → Conv2D[256] → Dropout → Conv2D[256] → Dropout → Conv2D[128] → Output. n denotes the number of sample points and f the number of features; in our implementation, n = 10,000 and f = 15. The segmentation output $y \in \mathbb{R}^{n \times p}$ gives pointwise predictions for p semantic labels. ⊗ represents concatenation. Conv2D[64] denotes a 2D convolutional layer with 64 filters.

4.2.1 Transform Net

We first align each input point cloud to a canonical space with a transform net Qi et al. (2017a); the aligned point cloud is then fed into the subsequent EdgeConv blocks for representation learning and segmentation. Specifically, we sample n points on the input IOS mesh to generate the input x of shape n × f, where f = 15 is the feature dimension defined in the feature extraction process above; in our implementation, n = 10,000. The transform net consists of consecutive 2D convolutional layers (Conv2D), a max-pooling layer, and fully connected (FC) layers. In particular, it has four Conv2D layers with 64, 128, 128, and 1024 filters, respectively, each using a 1 × 1 kernel and a stride of 1, followed by batch normalization and ReLU activation. The output of the convolutional layers is max-pooled and passed through two consecutive FC layers with 256 and 512 hidden units. Finally, the input point cloud is transformed into a canonical space by the transformation matrix estimated by the transform net.
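A NumPy forward pass can illustrate the transform net's structure. This sketch assumes, as in PointNet and DGCNN, that the net regresses a 3 × 3 matrix applied to the xyz columns (here the face centers $v_c$), that the last layer starts from the identity, and that batch normalization can be omitted for illustration; the helper names are hypothetical and the weights are random.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def dense(in_dim, out_dim, rng):
    """He-initialised weights and zero bias for one shared (per-point) linear layer."""
    return rng.normal(0.0, np.sqrt(2.0 / in_dim), (in_dim, out_dim)), np.zeros(out_dim)


def init_transform_net(in_dim, rng):
    convs = [dense(a, b, rng) for a, b in [(in_dim, 64), (64, 128), (128, 128), (128, 1024)]]
    fcs = [dense(1024, 256, rng), dense(256, 512, rng)]
    # Final layer regresses a 3x3 matrix; zero weights + identity bias start at the identity.
    head = (np.zeros((512, 9)), np.eye(3).ravel())
    return convs, fcs, head


def predict_transform(params, points):
    convs, fcs, (w_head, b_head) = params
    h = points                                # (n, f); batch normalization omitted
    for w, b in convs:
        h = relu(h @ w + b)                   # 1x1 Conv2D == the same linear map on every point
    g = h.max(axis=0)                         # symmetric max-pool over points -> (1024,)
    for w, b in fcs:
        g = relu(g @ w + b)
    return (g @ w_head + b_head).reshape(3, 3)


def align(params, points):
    """Apply the predicted 3x3 matrix to the xyz columns (v_c) only; an assumption here."""
    aligned = points.copy()
    aligned[:, :3] = points[:, :3] @ predict_transform(params, points)
    return aligned
```

The max-pool over points is what makes the predicted matrix independent of point order, which matters because the sampled face centers have no canonical ordering.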

4.2.2 Network Architecture

The segmentation backbone is built on the EdgeConv block, illustrated at the bottom of Figure 2. EdgeConv captures local geometric features while preserving permutation invariance Wang et al. (2019), and stacking multiple EdgeConv layers further captures high-level semantic features. Let $X = \{x_1, x_2, \ldots, x_n\} \in \mathbb{R}^{n \times f}$ represent an f-dimensional point cloud with n sample points. The point cloud can be represented as a directed graph $\mathcal{G} = (\mathcal{E}, \mathcal{V})$, where $\mathcal{E}$ and $\mathcal{V}$ denote edges and vertices, respectively. EdgeConv uses a k-NN graph to compute the edge feature $e_{ij}$ between each point $x_i$ and its k nearest neighbors $\{x_j : (i, j) \in \mathcal{E}\}$. In our network, the m-th channel of $e_{ij}$ is computed as

$$e_{ijm} = \mathrm{ReLU}\big(\theta_m \cdot (x_j - x_i) + \phi_m \cdot x_i\big), \qquad (4)$$

where $\theta_1, \ldots, \theta_m$ and $\phi_1, \ldots, \phi_m$ are learnable parameters, ReLU is the rectified linear unit, and $x_i, x_j \in \mathbb{R}^{15}$ are point features. Compared with the vanilla EdgeConv block, our EdgeConv applies channel-wise symmetric aggregation with both max-pooling and mean-pooling to better fuse neighbor information, which can help learn more expressive local and global representations. Our EdgeConv block consists of a k-NN graph for edge feature extraction, three consecutive Conv2D layers with 64 filters, and a max/mean-pooling layer. We set k = 30 for the segmentation network and stack three EdgeConv blocks to extract hierarchical representations; the outputs of the three blocks are concatenated and fused by another Conv2D layer and max-pooling. Furthermore, to include category information about the mandible and maxilla, we feed a one-hot categorical vector into a Conv2D layer and concatenate its output with the point representations learned by the stacked EdgeConv blocks. The detailed settings of each layer are illustrated in Figure 2.
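The following NumPy sketch shows one EdgeConv layer as described above: a k-NN graph built on the current features, the edge function of Eq. 4 realized as a single linear layer, and max plus mean aggregation over neighbors. It simplifies the block (one layer instead of three Conv2D layers) and assumes the max- and mean-pooled features are concatenated, which the text does not specify; the function names are hypothetical.

```python
import numpy as np


def knn_indices(x, k):
    """Indices of the k nearest neighbours of every point (the point itself included)."""
    sq = np.sum(x ** 2, axis=1)
    dist2 = sq[:, None] - 2.0 * x @ x.T + sq[None, :]   # pairwise squared distances
    return np.argsort(dist2, axis=1)[:, :k]              # (n, k)


def edge_conv(x, theta, phi, k):
    """One EdgeConv layer: e_ij = ReLU(theta.(x_j - x_i) + phi.x_i), then max+mean over j."""
    idx = knn_indices(x, k)                  # graph rebuilt from the current features
    x_i = x[:, None, :]                      # (n, 1, f)
    x_j = x[idx]                             # (n, k, f)
    edges = np.maximum((x_j - x_i) @ theta + x_i @ phi, 0.0)   # (n, k, c)
    # Symmetric aggregation over neighbours; concatenating max and mean is an assumption.
    return np.concatenate([edges.max(axis=1), edges.mean(axis=1)], axis=1)  # (n, 2c)


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 15))
h1 = edge_conv(x, rng.normal(size=(15, 64)), rng.normal(size=(15, 64)), k=30)
h2 = edge_conv(h1, rng.normal(size=(128, 64)), rng.normal(size=(128, 64)), k=30)
print(h1.shape, h2.shape)  # (200, 128) (200, 128)
```

Because `knn_indices` runs on the output of the previous layer, the second call builds its graph in feature space rather than in 3D space, which is the dynamic-graph behavior that lets deeper blocks group semantically similar points.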

The overall schema of our tooth segmentation network can be summarized as follows. The input mesh is transformed into a 3D point cloud $x \in \mathbb{R}^{10000 \times 15}$, which is sent to the transform net to be aligned into a canonical space. Afterward, three stacked EdgeConv blocks extract hierarchical point representations: the bottom EdgeConv block learns local geometric features, while the top EdgeConv block learns high-level semantic features. This hierarchy arises because the k nearest neighbors are computed from the outputs of the preceding EdgeConv block, so the proximity is dynamically updated from layer to layer, yielding local geometric features in the bottom blocks and high-level semantic features in the top blocks. These hierarchical features are gathered by a Conv2D layer and a max-pooling layer to obtain pointwise features that fuse local and global representations. After the one-hot category information is encoded and concatenated, the outputs are fed into a series of convolutional and dropout layers to learn global features and generate the final segmentation result $\hat{y} \in \mathbb{R}^{10000 \times 33}$, where 33 is the number of classes in our segmentation task.

4.2.3 Loss function

We train the network with the cross-entropy loss. To avoid overfitting, we add an L2 regularization term to the cross-entropy loss. The loss function is defined as

$$L = \underbrace{-\frac{1}{N} \sum_{i=1}^{N} y_i^{\top} \ln \operatorname{Softmax}(\hat{y}_i)}_{\text{cross-entropy loss}} + \underbrace{\lambda \sum_{\upsilon} \lVert \upsilon \rVert_2^2}_{\text{L2 loss}}, \qquad (5)$$

where N denotes the number of points in the point cloud, $y_i$ (one-hot) and $\hat{y}_i$ represent the true label and the predicted label of the i-th point, respectively, $\upsilon$ denotes the weight parameters of the model, and $\lambda$ is a hyperparameter that adjusts the weight of the L2 regularization term.
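Eq. 5 can be computed directly from the network's logits, as in the sketch below. It assumes integer class labels (equivalent to the one-hot $y_i$), a numerically stable log-softmax, and a squared L2 norm for the regularizer; `segmentation_loss` is a hypothetical name.

```python
import math

import numpy as np


def log_softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)     # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def segmentation_loss(logits, labels, weights, lam):
    """Eq. (5): mean pointwise cross-entropy plus lam * sum of squared L2 norms of the weights.

    logits: (N, C) raw network outputs y_hat; labels: (N,) integer class indices;
    weights: iterable of weight arrays (the upsilon terms); lam: regularisation strength.
    """
    ce = -np.mean(log_softmax(logits)[np.arange(len(labels)), labels])
    l2 = sum(float(np.sum(w ** 2)) for w in weights)
    return ce + lam * l2


# Uniform logits over 33 classes give a cross-entropy of ln(33).
print(math.isclose(segmentation_loss(np.zeros((4, 33)), np.array([0, 5, 9, 32]), [], 0.0), math.log(33)))  # True
```

Indexing the log-probabilities with the integer labels picks exactly the term that the one-hot dot product $y_i^{\top} \ln \operatorname{Softmax}(\hat{y}_i)$ keeps.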

4.3 Encryption Authority

The encryption authority is a third-party agency that provides public and private keys to the clients participating in federated learning. The encryption authority must be independent of the server, i.e., the server must not obtain the private key for decryption. Here, we introduce homomorphic encryption (HE) into our federated learning framework. HE allows mathematical operations to be performed on ciphertexts such that the decrypted result matches the result of performing the same operations directly on the plaintexts. Rivest et al. (1978) introduced the concept of homomorphic encryption. From the perspective of the supported operations, HE has two main branches: fully homomorphic encryption (FHE) and partially homomorphic encryption (PHE). PHE is further divided into additively homomorphic and multiplicatively homomorphic schemes, whereas FHE supports both addition and multiplication. Consequently, FHE requires far more computational resources than PHE. An HE scheme typically consists of three algorithms: key generation, encryption, and decryption, denoted by Keygen, Enc(·), and Dec(·), respectively. Keygen generates the public key $K_p$ for encryption and the private key $K_v$ for decryption. Let $m_1$ and $m_2$ denote plaintexts; HE has the following property:

$$\mathrm{Enc}(m_1, K_p) \otimes \mathrm{Enc}(m_2, K_p) = \mathrm{Enc}(m_1 \oplus m_2, K_p), \qquad (6)$$

where ⊗ and ⊕ represent mathematical operators. An encryption method satisfying Eq. 6 is homomorphic with respect to the ⊕ operation, such as addition or multiplication. Recent work by Kaissis et al. (2020) and Shah et al. (2021) has shown that additive homomorphism is popular and feasible in distributed learning. Our encryption authority employs the Paillier algorithm, an additively homomorphic scheme proposed by Paillier, (1999): $\mathrm{Enc}(m_1, K_p) \cdot \mathrm{Enc}(m_2, K_p) = \mathrm{Enc}(m_1 + m_2, K_p)$, where the ciphertext product is taken modulo the square of the public modulus.

During communication, each client encrypts its local model parameters with the public key obtained from the encryption authority, i.e., $\zeta_t^k \leftarrow \mathrm{Enc}(\omega_t^k, K_p)$. After gathering all clients' uploaded messages, the server aggregates them as shown in Eq. 3 and then distributes the updated message $\zeta_{t+1}$ and the total size of the training data $N$ to each client. The updated model parameters are then calculated as $\omega_{t+1}^k \leftarrow \mathrm{Dec}(\zeta_{t+1}, K_v) / N$.

4.3.1 Paillier Cryptosystem

The Paillier cryptosystem is a probabilistic public-key encryption scheme proposed by Paillier in 1999. It is additively homomorphic, i.e., multiplying two ciphertexts yields an encryption of the sum of their plaintexts, and it also supports multiplying the underlying plaintext by a non-negative integer (by raising the ciphertext to that power). Its key generation, encryption, and decryption procedures are presented in Algorithm 1.
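The Paillier operations and the encrypted aggregation described above (Eq. 3 on the server, then decryption and division by $N$ on each client) can be combined in one toy Python sketch. It uses the common simplification g = n + 1, tiny primes found by trial division, and a hypothetical fixed-point encoding (`SCALE`) that turns float parameters into integers modulo n; all of these are illustrative assumptions, and the key size is far too small to be secure. A real deployment would use a vetted library and a modulus of at least 2048 bits.

```python
import math
import random


def is_prime(x):
    """Trial division; adequate only for the toy-sized primes used in this sketch."""
    if x < 2:
        return False
    d = 2
    while d * d <= x:
        if x % d == 0:
            return False
        d += 1
    return True


def next_prime(x):
    while not is_prime(x):
        x += 1
    return x


def keygen(p, q):
    """Paillier key generation with g = n + 1; returns (public n, private (n, lam, mu))."""
    n = p * q
    lam = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)   # lcm(p - 1, q - 1)
    mu = pow(lam, -1, n)                                  # valid because g = n + 1
    return n, (n, lam, mu)


def encrypt(m, n, rng):
    """c = g^m * r^n mod n^2 with g = n + 1 and a random r coprime to n."""
    n2 = n * n
    r = rng.randrange(1, n)
    while math.gcd(r, n) != 1:
        r = rng.randrange(1, n)
    return (pow(n + 1, m, n2) * pow(r, n, n2)) % n2


def decrypt(c, private_key):
    """m = L(c^lam mod n^2) * mu mod n, with L(u) = (u - 1) / n."""
    n, lam, mu = private_key
    u = pow(c, lam, n * n)
    return ((u - 1) // n) * mu % n


SCALE = 10 ** 6   # hypothetical fixed-point precision for float parameters


def encode(w, n):
    return round(w * SCALE) % n            # negative values wrap around modulo n


def decode(m, n):
    return (m - n if m > n // 2 else m) / SCALE


def server_aggregate(client_ciphertexts, client_sizes, n):
    """Eq. (3) on ciphertexts: c**N_k scales a plaintext by N_k, c1 * c2 adds plaintexts."""
    n2 = n * n
    aggregated = []
    for column in zip(*client_ciphertexts):   # one parameter position across all clients
        acc = 1                               # multiplicative identity; decrypts to 0
        for c, n_k in zip(column, client_sizes):
            acc = acc * pow(c, n_k, n2) % n2
        aggregated.append(acc)
    return aggregated


rng = random.Random(0)
p = next_prime(10 ** 6)
q = next_prime(p + 1)
n, private_key = keygen(p, q)
weights = [[0.5, -1.25], [0.1, 0.2], [-0.7, 0.9]]   # three clients, two parameters each
sizes = [120, 300, 80]
ciphertexts = [[encrypt(encode(w, n), n, rng) for w in ws] for ws in weights]
aggregated = server_aggregate(ciphertexts, sizes, n)
N = sum(sizes)
global_w = [decode(decrypt(c, private_key), n) / N for c in aggregated]
expected = [sum(s * ws[j] for s, ws in zip(sizes, weights)) / N for j in range(2)]
```

The server only ever multiplies and exponentiates ciphertexts, so it never sees any client's parameters; `global_w` matches the plaintext FedAvg result in `expected` up to the fixed-point rounding.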

Based on FedAvg McMahan et al. (2017), we built our privacy-preserving framework with a homomorphic encryption mechanism. Assuming n clients collaboratively contribute to a global model, the objective is to minimize the global loss function

$$\arg\min_{\omega} \sum_{k=1}^{n} \alpha_k L_k(\omega), \qquad (7)$$

where $\alpha_k = N_k / N$ and $\sum_{k=1}^{n} \alpha_k = 1$. $L_k(\cdot)$ is the local loss function of client k defined in Eq. 1. In our framework, all participants update their local models for E local epochs, using SGD to compute the weight updates in each round. The overall FedTSeg algorithm is presented in Algorithm 2, and key notations are listed in Table 1.

Algorithm 1. Paillier cryptosystem.

Algorithm 2. FedTSeg.
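An end-to-end sketch of the training loop in Algorithm 2, under simplifying assumptions: each client holds a toy linear-regression dataset instead of IOS point clouds, the encryption and decryption steps are omitted (their positions are marked in comments), and `run_federated` and `local_sgd` are hypothetical names.

```python
import numpy as np


def local_sgd(w, X, y, epochs, lr, batch_size, rng):
    """Client update: E epochs of mini-batch SGD on a toy squared-error loss."""
    w = w.copy()
    for _ in range(epochs):
        order = rng.permutation(len(y))
        for start in range(0, len(y), batch_size):
            b = order[start:start + batch_size]
            w -= lr * 2.0 * X[b].T @ (X[b] @ w - y[b]) / len(b)
    return w


def run_federated(clients, rounds, epochs, lr=0.02, batch_size=4, seed=0):
    """FedAvg-style rounds; the encrypt/decrypt steps are omitted and marked in comments."""
    rng = np.random.default_rng(seed)
    sizes = [len(y) for _, y in clients]
    total = sum(sizes)                                   # N
    w_global = np.zeros(clients[0][0].shape[1])
    for _ in range(rounds):
        # 1) Local training on every client (in FedTSeg the result is then encrypted).
        local = [local_sgd(w_global, X, y, epochs, lr, batch_size, rng) for X, y in clients]
        # 2) Server: integer-weighted sum of Eq. (3) (performed on ciphertexts in FedTSeg).
        summed = sum(n_k * w_k for n_k, w_k in zip(sizes, local))
        # 3) Clients: decrypt (omitted here) and divide by N locally.
        w_global = summed / total
    return w_global


rng = np.random.default_rng(1)
w_true = np.array([1.0, -2.0, 0.5])
clients = []
for size in (40, 100, 25):          # unequal client sizes, as in the imbalanced setting
    X = rng.normal(size=(size, 3))
    clients.append((X, X @ w_true))
w_global = run_federated(clients, rounds=10, epochs=2)
```

Weighting by $N_k$ means the global update minimizes the sample-weighted objective of Eq. 7, so a large hospital's data counts proportionally more than a small clinic's.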


TABLE 1. Symbol notations.

5 Experiment

5.1 Experimental Setup

5.1.1 Dataset and Preprocessing

We collected 3,000 IOS meshes annotated by experts. Each scan is a 3D mesh of a patient's mandible or maxilla, with a label on each face denoting the corresponding tooth or gingiva. We randomly split the dataset into 80% for training and 20% for testing. The training set is distributed to clients according to each experimental setting, while all clients share the same testing set. To improve generalization, the training samples are randomly transformed to obtain multiple augmentations. In particular, the transformation is defined as $T = R|S|\tau$, where $R$, $S$, and $\tau$ denote rotation, scaling, and translation, respectively. For the rotation $R$, the mesh is rotated around an arbitrary axis by an angle randomly chosen from [0°, 5°, 10°]. The scaling $S$ scales the mesh by a factor randomly chosen from (83%, 115%). The translation $\tau$ applies a global shift to the mesh in 3D Cartesian coordinates. Each data sample is augmented four times, and the corresponding features are extracted following the procedure given in Section 4.1.1.
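The augmentation $T$ can be sketched with NumPy as follows. The sketch rests on several assumptions:

- the transform is applied to mesh vertex coordinates before feature extraction;
- the rotation angle is drawn from the discrete set {0°, 5°, 10°};
- the scale factor is drawn uniformly from [0.83, 1.15);
- the translation range `max_shift` is a hypothetical parameter, since the range is not specified above.

Function names are illustrative.

```python
import numpy as np


def random_transform(vertices: np.ndarray, rng: np.random.Generator,
                     max_shift: float = 1.0) -> np.ndarray:
    """Apply rotation, uniform scaling, then translation to (V, 3) vertices."""
    # Rotation R: random unit axis, angle drawn from {0, 5, 10} degrees.
    axis = rng.normal(size=3)
    axis /= np.linalg.norm(axis)
    theta = np.deg2rad(rng.choice([0.0, 5.0, 10.0]))
    kx, ky, kz = axis
    K = np.array([[0.0, -kz, ky],
                  [kz, 0.0, -kx],
                  [-ky, kx, 0.0]])  # cross-product matrix of the axis
    R = np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)  # Rodrigues' formula
    # Scaling S: one factor for all axes, drawn uniformly from [0.83, 1.15).
    s = rng.uniform(0.83, 1.15)
    # Translation tau: global shift; the range max_shift is hypothetical.
    tau = rng.uniform(-max_shift, max_shift, size=3)
    return s * (vertices @ R.T) + tau


def augment(vertices: np.ndarray, copies: int = 4, seed: int = 0):
    """Return `copies` independently transformed versions of one mesh."""
    rng = np.random.default_rng(seed)
    return [random_transform(vertices, rng) for _ in range(copies)]


mesh = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]])
print([a.shape for a in augment(mesh)])  # four arrays of shape (4, 3)
```

Because the rotation is orthogonal and the scaling is uniform, each augmented copy preserves the shape of the mesh up to size, position, and orientation.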

5.1.2 Implementation Details

As shown in Figure 2, the network architecture is: input mesh → point cloud $x \in \mathbb{R}^{10000 \times 15}$ → Transform Net → EdgeConv → EdgeConv → EdgeConv → Conv2D[1024] → Maxpool → Conv2D[256] → Dropout → Conv2D[256] → Dropout → Conv2D[128] → Output. The number in brackets denotes the number of filters in the 2D convolutional layer. Unless otherwise indicated, all Conv2D layers use a 1 × 1 kernel and a stride of 1, followed by batch normalization and ReLU activation. We set k = 30 for the k-NN graph in each EdgeConv block. The dropout ratio is 0.6. The output layer is a Conv2D[33] without batch normalization or ReLU activation. The network is trained with a batch size of 4. During local training, each client uses the Adam optimizer with an initial learning rate η = 0.002, an exponential decay rate of 0.7, and 160,000 decay steps. For a fair comparison, all models are evaluated on the same testing set. All experiments, including training, testing, encryption/decryption, and inference, are conducted on an NVIDIA GeForce RTX 3090 GPU (24 GB) and an Intel(R) Xeon(R) Gold 6139M CPU @ 2.30GHz.
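The NumPy sketch below shows how one EdgeConv block works. It builds a k-NN graph (k = 30) over per-point features and forms edge features $[x_i, x_j - x_i]$. It then aggregates them with a shared 1 × 1 layer and ReLU, followed by a max over neighbours. This is a simplified illustration under several assumptions:

- batch normalization is omitted;
- self-loops are excluded from the neighbourhood;
- the weights are hypothetical random values;
- a 10,000-point cloud would need a batched GPU implementation instead of the dense distance matrix used here.

```python
import numpy as np


def knn_indices(x: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k nearest neighbours of every point, excluding the point itself."""
    sq = np.sum(x * x, axis=1)
    dist = sq[:, None] - 2.0 * (x @ x.T) + sq[None, :]  # pairwise squared distances
    np.fill_diagonal(dist, np.inf)  # drop self-loops
    return np.argsort(dist, axis=1)[:, :k]


def edge_features(x: np.ndarray, k: int = 30) -> np.ndarray:
    """Edge features concat(x_i, x_j - x_i) with shape (P, k, 2F)."""
    idx = knn_indices(x, k)
    neighbours = x[idx]  # (P, k, F)
    central = np.repeat(x[:, None, :], k, axis=1)  # (P, k, F)
    return np.concatenate([central, neighbours - central], axis=-1)


def edge_conv(x: np.ndarray, weight: np.ndarray, k: int = 30) -> np.ndarray:
    """Shared linear map (a 1x1 convolution) + ReLU, then max over the k neighbours."""
    h = np.maximum(edge_features(x, k) @ weight, 0.0)  # (P, k, C)
    return h.max(axis=1)  # (P, C)


rng = np.random.default_rng(0)
points = rng.normal(size=(500, 15))  # 500 hypothetical points with 15 features each
w = rng.normal(size=(30, 64))  # hypothetical weights: 2 * 15 inputs -> 64 channels
print(edge_conv(points, w).shape)  # (500, 64)
```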

5.1.3 Federated Setting

In our experiments, we simulate a real-world situation in which different medical institutions collaborate on 3D tooth segmentation. We consider two distributions, "balanced" and "imbalanced." In the balanced distribution, every client has a training set of the same size; in the imbalanced distribution, the sizes may differ. Let E denote the number of local training epochs between consecutive communication rounds. Unless otherwise specified, we set E = 20, i.e., every 20 epochs, each client sends its local model parameters to the server.
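The client split and synchronization schedule can be sketched as follows. The helper names are hypothetical. In the usage example, the imbalanced proportions are partly placeholders. Section 5.2.1 reports only the 8% shares of $c_1'$ and $c_2'$ and the 34% share of $c_5'$, so the shares of $c_3'$ and $c_4'$ are invented for illustration.

```python
import random
from typing import List, Sequence


def partition(indices: Sequence[int], proportions: Sequence[float], seed: int = 0) -> List[List[int]]:
    """Shuffle training indices and cut them into one shard per client."""
    if abs(sum(proportions) - 1.0) > 1e-9:
        raise ValueError("proportions must sum to 1")
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)
    shards, start = [], 0
    for i, p in enumerate(proportions):
        # The last client takes the remainder so that no sample is dropped.
        last = i == len(proportions) - 1
        end = len(shuffled) if last else start + round(p * len(shuffled))
        shards.append(shuffled[start:end])
        start = end
    return shards


def num_rounds(total_epochs: int, local_epochs: int = 20) -> int:
    """Communication rounds when clients synchronize every E local epochs."""
    return total_epochs // local_epochs


train_ids = range(2400)  # 80% of 3,000 meshes
balanced = partition(train_ids, [0.2] * 5)
imbalanced = partition(train_ids, [0.08, 0.08, 0.20, 0.30, 0.34])  # c_3', c_4' shares hypothetical
print([len(s) for s in balanced])    # [480, 480, 480, 480, 480]
print([len(s) for s in imbalanced])  # [192, 192, 480, 720, 816]
print(num_rounds(100))               # 5
```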

5.1.4 Metrics

In the performance analysis, we use three metrics to evaluate segmentation performance: mean intersection over union (mIoU), Dice similarity coefficient (DSC), and pointwise classification accuracy (Acc). Let X denote the segmentation result and Y the ground truth; then

$$\mathrm{DSC} = \frac{2|X \cap Y|}{|X| + |Y|}, \qquad \mathrm{mIoU} = \frac{|X \cap Y|}{|X \cup Y|}.$$
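The metrics can be computed per face as in the NumPy sketch below. Two details are our own assumptions. First, DSC and IoU are computed per class and averaged, with 33 classes matching the output layer in Section 5.1.2. Second, classes absent from both prediction and ground truth are skipped. The function name is hypothetical.

```python
import numpy as np


def segmentation_metrics(pred: np.ndarray, gt: np.ndarray, num_classes: int = 33):
    """Per-face Acc, plus DSC and IoU averaged over classes present in pred or gt."""
    acc = float(np.mean(pred == gt))
    dscs, ious = [], []
    for c in range(num_classes):
        x, y = pred == c, gt == c
        inter = np.logical_and(x, y).sum()
        union = np.logical_or(x, y).sum()
        if union == 0:
            continue  # class absent from both prediction and ground truth
        dscs.append(2.0 * inter / (x.sum() + y.sum()))
        ious.append(inter / union)
    return acc, float(np.mean(dscs)), float(np.mean(ious))


pred = np.array([0, 0, 1, 1])
gt = np.array([0, 1, 1, 1])
print(segmentation_metrics(pred, gt))  # Acc 0.75, mean DSC about 0.733, mIoU about 0.583
```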

5.2 Results

5.2.1 FL vs. w/o FL

We first compare the global and local models trained with FedTSeg (FL) against local models trained without FedTSeg (w/o FL) and their average. We simulate five independent clients under the balanced distribution (c_1, c_2, …, c_5), each using 20% of the original training set for local training. The results are shown in Table 2. The FedTSeg global model reaches 80.63% mIoU, 86.03% DSC, and 91.39% Acc, outperforming the average of the five local models without FL by 10.05% mIoU, 8.24% DSC, and 5.93% Acc. Such a large margin demonstrates the effectiveness and necessity of federated learning for 3D tooth segmentation.


TABLE 2. Performance of five clients’ local models when epoch = 100. “Size” means the proportion of the dataset for each client; w/FL means clients are within the FedTSeg framework; w/o FL means clients perform purely local training; global denotes the global model after aggregation; and avg calculates the average testing accuracy of five local models without the FL framework.

Among the local models trained without FedTSeg, c_3's model achieves the best performance (Acc 90.68%, DSC 85.10%, and mIoU 79.63%), while c_5's model performs poorly (Acc 75.30%, DSC 65.26%, and mIoU 56.88%). Although all training sets have the same size, the final results vary widely, indicating that data heterogeneity significantly influences local model performance. The FedTSeg global model, which learns from all clients' data, significantly outperforms every local model trained without FedTSeg. Moreover, the variance among the five clients decreases substantially with FedTSeg: the standard deviation of Acc drops from 5.27% (without FL) to 1.35% (with FL), indicating that FedTSeg materially alleviates the influence of data heterogeneity.

We further conduct experiments with five clients under the imbalanced distribution, denoted as (c_1′, c_2′, …, c_5′). As shown in the bottom part of Table 2, c_1′ and c_2′ hold the smallest training sets (8% each), and c_5′ holds the largest (34%), simulating small clinics and large hospitals in the real world. Without FedTSeg, the avg model under the imbalanced distribution is worse than that under the balanced distribution by 11.75% mIoU, 12.21% DSC, and 8.46% Acc. Because its training set is so small, c_1′'s model reaches only 64.64% Acc, 46.46% DSC, and 40.45% mIoU, while c_2′'s model achieves 84.17% Acc, 78.33% DSC, and 70.49% mIoU. Although both clients have training sets of the same size, the large performance gap between them indicates a significant influence of data heterogeneity. The performance variance is also larger than under the balanced distribution: the standard deviation of Acc increases from 5.27% to 9.08%, owing to both data heterogeneity and the varying amounts of local training data.

With FedTSeg, however, the performance of both the local and global models is significantly boosted; e.g., c_1′'s Acc increases from 64.64% to 89.62%. All five clients' local models improve while showing a smaller variance than the local models trained without FedTSeg. Moreover, the global model outperforms all local models, including those trained with FedTSeg. Thus, clients can obtain a better global model with FedTSeg, especially those that possess only a small amount of training data.

We also investigate the convergence of network training with and without FedTSeg under the imbalanced distribution. The training process of the five-client setting is illustrated in Figure 3A. Clients c_1′ and c_2′ converge more slowly than client c_5′, which possesses a larger local dataset. Convergence speed is thus positively correlated with local data size, but FedTSeg overcomes this limitation and accelerates training: all five clients converge faster within FedTSeg than with purely local training. Moreover, within FedTSeg the clients show smaller fluctuations in training accuracy, as demonstrated by the smoother curves in Figure 3A.


FIGURE 3. Convergence analysis of each experiment. (A) Convergence of five clients under the imbalanced distribution with FedTSeg (solid lines) and without FedTSeg (dashed lines) over 100 epochs. (B) Convergence for different numbers of clients and distributions with FedTSeg over 100 epochs. (C) Extension of the five-client and 15-client experiments in Table 3 to 260 epochs.

Lastly, we trained a model under the centralized learning (CL) paradigm, assuming that all data samples could be aggregated and made available during learning. The performance of this CL model can be regarded as an upper bound for our 3D tooth segmentation task. The results are reported in Table 2. FedTSeg performs on par with CL: CL outperforms the FedTSeg global model by only 1.07% mIoU, 0.55% DSC, and 0.73% Acc. This gap is small compared with the large margin between models with and without FedTSeg, demonstrating the effectiveness of FedTSeg when data sharing is prohibited due to privacy concerns.

5.2.2 More Clients in Imbalanced and Balanced Distribution

In real-world scenarios, more clients may participate in the FedTSeg framework. We therefore simulate FedTSeg under the balanced and imbalanced distributions with more clients, i.e., 15 and 20 clients. Figure 4 shows the local training set sizes under the imbalanced distribution and the test accuracy of each client's local model with FedTSeg. Table 3 reports the performance of the global model for different numbers of clients.


FIGURE 4. Details of the imbalanced distribution: local training set size and test accuracy of each local model at epoch = 100, before aggregation.


TABLE 3. Performance of imbalanced and balanced distribution on different clients’ number at epoch = 100.

The global model under the imbalanced distribution performs better than that under the balanced distribution. To better understand this effect, Figure 3B records the convergence of the experiments in Table 3, together with the convergence curve of the centralized learning paradigm for reference. Convergence under the imbalanced distribution is slightly faster than under the balanced distribution. We therefore conjecture that clients with large-scale, high-quality data dominate parameter updates, which contributes greatly to a strong global model, especially in the early stage of training. This conjecture is supported empirically by Figure 3C and Table 4: the Acc gap between the five-client and 15-client global models shrinks from 2.77% (trained for 100 epochs) to 1.38% (trained for 260 epochs). In particular, although the five-client global model saturates at 100 epochs, the 15-client global model still improves slowly as training continues. This is reasonable because the global model needs to be improved progressively, especially in settings with many clients, each holding fewer local training samples.


TABLE 4. Performance of the FL global model for 5 clients and 15 clients at epoch = 100 and epoch = 260.

Both the 15-client and 20-client global models also outperform the local models under the balanced and imbalanced distributions. This is consistent with the five-client results above and demonstrates that FedTSeg can combine the strengths of multiple clients and boost the performance of local models. Under the imbalanced distribution, the five-client global model outperforms the 15-client model by 3.17% mIoU, 1.83% DSC, and 2.77% Acc, and the 15-client model outperforms the 20-client model by 2.19% mIoU, 1.25% DSC, and 1.64% Acc. Under the balanced distribution, the five-client global model outperforms the 15-client model by 5.45% mIoU, 3.82% DSC, and 3.27% Acc, and the 15-client model outperforms the 20-client model by 1.04% mIoU, 0.86% DSC, and 0.36% Acc. Since all three experiments (5, 15, and 20 clients) use the same original training data, the average local training set shrinks as the number of clients increases, leading to the observed performance degradation.
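The size effect can be quantified with simple arithmetic. The sketch assumes the 80/20 split of Section 5.1.1, which gives 2,400 training meshes before augmentation. It shows how the average local training set shrinks as clients are added. The average is the same under both distributions; only the individual shares differ.

```python
TRAIN_MESHES = 3000 * 80 // 100  # 2,400 training meshes before augmentation


def avg_client_size(num_clients: int) -> float:
    """Average number of training meshes per client."""
    return TRAIN_MESHES / num_clients


for n in (5, 15, 20):
    print(n, avg_client_size(n))  # 5 480.0, then 15 160.0, then 20 120.0
```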

5.3 Visualizations

We further demonstrate the effectiveness of our method with case visualizations. Figure 5 displays five segmentation results under the balanced distribution, where each column represents one case. Comparing models with and without FedTSeg, the FedTSeg global model produces more precise segmentations than purely local training. Concretely, the "w/o FL" model misclassifies part of the lateral incisor as the cuspid in case 1. Without FedTSeg, there are also errors of omission at the boundary between the central incisor and the gingiva (tooth–gingiva boundary) in cases 2 to 5. Further mistakes occur at tooth–tooth boundaries, e.g., between the second bicuspid and the first molar in cases 2 and 5. For the molars, the shape of the second molar predicted by the "w/o FL" model visibly differs from the ground truth in case 4. In contrast, the model trained with FedTSeg, though not perfect, rarely makes such mistakes, yielding segmentation results that are better suited to real-world clinical applications.


FIGURE 5. Visualization of segmentation results under balanced distribution with FedTSeg framework and without FedTSeg framework. “FL” denotes the global model with FedTSeg and “w/o FL” denotes the local model under pure local training. The differences are annotated with dotted circles.

Figure 6 visualizes another five cases under the imbalanced distribution. Without FedTSeg, there are obvious mistakes in predicting parts of the cuspid and first bicuspid in cases 2, 4, and 5. In cases 2, 3, 4, and 5, the boundaries among the second molar, first molar, and second bicuspid are also segmented incorrectly. With FedTSeg, these errors are largely resolved, further demonstrating the effectiveness of FedTSeg under the imbalanced distribution.


FIGURE 6. Visualization of segmentation results under imbalanced distribution with FedTSeg and without FedTSeg. Some distinct errors are annotated with dotted circles.

5.4 Communication Analysis

In FedTSeg, each client encrypts its model parameters and decrypts the received ciphertext in every communication round using the Paillier algorithm. To quantify the additional cost that encryption introduces, we record the training, encryption, and decryption times. Specifically, we used TensorFlow to train our network, with the model stored in "ckpt" (checkpoint) format. We used the open-source library "Python-Paillier" to perform element-wise encryption on the model, i.e., each floating-point number in the tensors is encrypted with the public key. Each model contains 295 tensors with 5,456,841 floating-point values to be encrypted.

The results are reported in Table 5. Training one epoch takes about 6 min in the five-client setting with the balanced distribution. With a 256-bit public/private key pair, encrypting a local model takes 36 min and decrypting the incoming global model takes 15 min. Longer keys make the ciphertext more secure but also increase both encryption and decryption time. Note that the computing capability of the GPU/CPU affects both training and homomorphic encryption time, so the values in Table 5 serve only as a reference for our experimental setup described in Section 5.1.
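The back-of-the-envelope sketch below interprets these numbers. It assumes that the "256-bit key" refers to the Paillier modulus $n$ and that the plaintext weights are float32. Because each parameter is encrypted separately, the reported times translate into a per-parameter cost. Each ciphertext is an element of $\mathbb{Z}_{n^2}$ and so occupies up to 512 bits instead of 32. The size estimate assumes a fixed 64-byte encoding per ciphertext and ignores serialization or encoding metadata added by the library, so it is an estimate rather than a measured figure.

```python
NUM_PARAMS = 5_456_841  # floating-point values per model (295 tensors)
KEY_BITS = 256          # assumed Paillier modulus size


def per_param_ms(total_minutes: float) -> float:
    """Average time per encrypted/decrypted parameter in milliseconds."""
    return total_minutes * 60_000 / NUM_PARAMS


def payload_mb(bits_per_value: int) -> float:
    """Size of one model's parameters in megabytes (10^6 bytes)."""
    return NUM_PARAMS * bits_per_value / 8 / 1e6


print(f"encryption: {per_param_ms(36):.3f} ms/param")     # 0.396
print(f"decryption: {per_param_ms(15):.3f} ms/param")     # 0.165
print(f"float32 plaintext: {payload_mb(32):.1f} MB")      # 21.8
print(f"ciphertexts: {payload_mb(2 * KEY_BITS):.1f} MB")  # 349.2
```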


TABLE 5. Time analysis of the homomorphic encryption process. For training time, we record the average time for training one epoch with 20% data.

6 Discussion

Our work nevertheless has some limitations. First, a gap remains between FL and the centralized paradigm. As shown in Figure 7A, FedTSeg makes some mistakes in segmenting small parts of the lateral incisor, whereas the CL segmentation is more precise. Meanwhile, although FedTSeg can detect missing unlabeled teeth and avoid manual annotation mistakes, it is still inferior to CL in extreme cases. Ideally, an applicable federated learning framework should reach, or even surpass, the performance of a centralized learning paradigm. More domain-specific designs might be required to better solve federated 3D tooth segmentation, e.g., frameworks that consider the anatomical and morphological features of different oral diseases.


FIGURE 7. Visualization of segmentation results under a centralized paradigm and federated learning framework. (A) Comparison between FL and CL. (B) Comparison between FL and GT.

Moreover, a federated learning framework should generalize efficiently to settings with a large number of clients. In our preliminary results, FedTSeg requires more training time to converge when there are more clients, each of which holds fewer samples in our simulation. Enlarging the local training sets might help FedTSeg converge faster with many clients. However, novel designs are also appealing for scenarios with many distributed clients, each holding relatively few training samples. This is the typical situation in practice when many small clinics, rather than a few large hospitals, take part in federated medical image analysis.

Last but not least, further exploration is needed to improve communication and model-aggregation efficiency, although some recent research has already shed light on this direction. Clients differ in computing capability: hospitals with more computing power finish training and encryption faster and then have to wait for the others to finish communicating. By the cannikin (weakest-link) law, our framework is limited by the slowest client. Asynchronous model aggregation might be adopted for faster convergence. Moreover, the weights of the neural networks could be quantized or pruned to reduce communication overhead. Combining federated learning with cutting-edge deep neural network compression techniques is a promising way to learn powerful models efficiently with many mobile edge devices.
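The NumPy sketch below applies affine 8-bit quantization to a weight tensor, as one example of the compression direction mentioned above. It is not part of FedTSeg. The scheme and function names are assumptions chosen for simplicity; practical systems often use per-channel scales or learned quantizers instead.

```python
import numpy as np


def quantize_uint8(w: np.ndarray):
    """Affine quantization: w is approximated by scale * q + w_min with q in [0, 255]."""
    w_min, w_max = float(w.min()), float(w.max())
    scale = (w_max - w_min) / 255.0 if w_max > w_min else 1.0
    q = np.clip(np.round((w - w_min) / scale), 0, 255).astype(np.uint8)
    return q, scale, w_min


def dequantize_uint8(q: np.ndarray, scale: float, w_min: float) -> np.ndarray:
    """Map 8-bit codes back to approximate float32 weights."""
    return q.astype(np.float32) * scale + w_min


rng = np.random.default_rng(0)
weights = rng.normal(size=10_000).astype(np.float32)  # hypothetical layer weights
q, scale, w_min = quantize_uint8(weights)
err = np.abs(dequantize_uint8(q, scale, w_min) - weights).max()
print(q.nbytes, weights.nbytes)  # 10000 40000: 4x smaller before encryption
print(err <= scale / 2 + 1e-4)   # error stays within half a step (plus float32 round-off)
```

Under element-wise Paillier encryption, every ciphertext has the same size regardless of the plaintext's bit-width. The 4× saving would therefore carry over to the encrypted upload only if several quantized values were packed into one plaintext before encryption.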

7 Conclusion

In conclusion, we design and develop FedTSeg, a federated learning framework for 3D IOS tooth segmentation that preserves privacy via homomorphic encryption. Comprehensive experimental results show that FedTSeg obtains much better global and local models than conventional local training. It also achieves segmentation performance on par with the centralized paradigm, in which all data are assumed to be aggregated. Moreover, the exchanged model parameters are protected by homomorphic encryption, preventing attacks during parameter exchange. Our future work will focus on the performance degradation in scenarios with a large number of clients, each holding limited local training data, which is closely related to the long-tailed learning problem. Techniques that improve communication efficiency are also needed to deploy our framework in real-world applications.

Data Availability Statement

The raw data supporting the conclusions of this article will be made available by the authors, without undue reservation.

Author Contributions

ZL initialized the project. SL, HY, and ZL designed the framework. JH and YF prepared the dataset. SL, YT, HY, and ZL implemented the framework and performed statistical analysis. SL and ZL wrote the manuscript. All authors significantly revised and approved the manuscript.

Funding

This work was supported by the National Natural Science Foundation of China (62106222).

Conflict of Interest

YF was employed by the company Angelalign Inc.

The remaining authors declare that the research was conducted in the absence of any commercial or financial relationships that could be construed as a potential conflict of interest.

Publisher’s Note

All claims expressed in this article are solely those of the authors and do not necessarily represent those of their affiliated organizations, or those of the publisher, the editors, and the reviewers. Any product that may be evaluated in this article, or claim that may be made by its manufacturer, is not guaranteed or endorsed by the publisher.

Acknowledgments

The authors thank Xiang Li, Mingzhou Wu, and Chenxi Liu for their helpful discussion and initial work on this study.

References



Source: https://www.frontiersin.org/