
Computer Engineering

   

Data Augmentation for Federated Image Classification Using Diffusion Models

  

  • Published: 2026-04-28


Abstract: Federated learning is a distributed machine learning paradigm that leverages decentralized data resources while preserving data privacy. In real-world scenarios, however, data across clients are often non-IID (not independent and identically distributed), leading to label shift and class imbalance, which hinder the convergence of the global model and degrade its generalization performance. To mitigate the impact of such data heterogeneity on model performance, we propose a cross-client data augmentation and classification framework based on diffusion models. In this framework, each client trains an initial diffusion model on its local data and uploads the model parameters to the server. The server aggregates these parameters into a global diffusion model, which is then sent back to all clients. Each client uses the global diffusion model to generate supplementary samples and uploads them to the server, where they augment the training data and balance the local class distributions, thereby improving classifier performance. Finally, the classification model is trained via federated learning on both the local data and the generated samples, and is deployed to clients for image classification and recognition. To generate high-quality images, a denoising diffusion probabilistic model (DDPM) serves as the generation backbone, while a ResNet-18 architecture is used for the federated classification model. Experimental results show that the fine-tuned global diffusion model generates images more consistent with the real data distribution. Augmenting the data with generated samples makes the local data distribution on each client more balanced, significantly improving global classification accuracy.
Under the non-IID condition with a Dirichlet coefficient α=0.1, accuracy on CIFAR-10 and CIFAR-100 increased from 46.76% and 21.31% to 54.64% and 25.57%, respectively, demonstrating the effectiveness of the proposed data augmentation strategy in mitigating class imbalance.
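The server-side aggregation step described in the abstract — combining client diffusion-model parameters into a global model — is typically realized with sample-count-weighted FedAvg. The sketch below is a minimal illustration of that weighting (the abstract does not specify the aggregation rule, so FedAvg is an assumption; parameters are represented as plain NumPy state dicts rather than a real diffusion model):

```python
import numpy as np

def fedavg(client_states, client_sizes):
    """FedAvg: weighted average of client parameter dicts.

    client_states: list of {name: np.ndarray} state dicts, one per client.
    client_sizes:  number of local training samples per client (the weights).
    """
    total = float(sum(client_sizes))
    global_state = {}
    for name in client_states[0]:
        # Each parameter tensor is averaged with weight n_k / n_total.
        global_state[name] = sum(
            (n / total) * state[name]
            for state, n in zip(client_states, client_sizes)
        )
    return global_state

# Toy example: two clients sharing one "layer"; client b holds 3x the data.
a = {"w": np.array([1.0, 3.0])}
b = {"w": np.array([3.0, 5.0])}
g = fedavg([a, b], client_sizes=[1, 3])
print(g["w"])  # pulled toward client b: [2.5 4.5]
```

In the proposed framework this aggregation would run twice: once for the diffusion model and once for the ResNet-18 classifier.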
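The experiments simulate label skew with a Dirichlet coefficient α=0.1, a standard protocol in which each class's samples are split across clients according to Dirichlet-drawn proportions; smaller α yields more imbalanced local distributions. A minimal sketch of that partitioning (the paper does not give its exact splitting code, so the helper below is illustrative):

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Assign sample indices to clients with a per-class Dirichlet prior.

    Smaller alpha -> more skewed (non-IID) per-client label distributions.
    """
    rng = np.random.default_rng(seed)
    clients = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        rng.shuffle(idx)
        # Fraction of class-c samples that each client receives.
        props = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for client, part in zip(clients, np.split(idx, cuts)):
            client.extend(part.tolist())
    return clients

# Toy labels: 10 classes x 100 samples, alpha=0.1 as in the experiments.
labels = np.repeat(np.arange(10), 100)
parts = dirichlet_partition(labels, num_clients=5, alpha=0.1)
assert sum(len(p) for p in parts) == len(labels)  # every sample assigned once
```

After such a split, some clients see only a few classes, which is the imbalance the generated samples are meant to correct.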
