Skip to content

Latest commit

 

History

History
1433 lines (1312 loc) · 64.7 KB

readme-zh.md

File metadata and controls

1433 lines (1312 loc) · 64.7 KB

Async-FL

GitHub code size license python torch

This document is also available in: 中文 | English

keywords: federated-learning, asynchronous, synchronous, semi-asynchronous, personalized

目录

初衷

本项目的初衷是我本科毕设期间需要完成搭建一个异步联邦学习框架,并且在其之上完成一些实验。

可当我去github尝试搜索项目时,发现异步联邦学习闭源之深,几乎没有开源项目。并且主流框架也基本不兼容异步,只支持同步FL。因此促生了该项目。

git分支说明

master分支为主分支,代码为最新,但有部分commit是脏commit,不保证每个commit都能正常运行,建议使用打tag(版本号)的version

checkout分支保留了客户端会随着训练过程进行不断加入框架中,主分支已经移除该功能,checkout分支并不维护,只支持同步和异步。

基本配置

python3.8 + pytorch + macos

在linux进行过验证

支持单GPU,尚未进行多GPU优化

运行

实验

直接运行python main.py(fl下的main文件)即可,程序会自动读取根目录下的config.json文件,执行完后将结果储存到results下的指定路径下,并将配置文件一并存储。

也可以自行指定配置文件python main.py ../../config.json,需要注意的是config.json的路径是基于main.py的。

根目录下的config文件夹提供了部分论文提出的算法文件配置,现提供如下算法实现:

FedAvg
FedAsync
FedProx
FedAT
FedLC
FedDL
M-Step AsyncFL

Docker

现在可以直接pull docker镜像进行运行,命令如下:

docker pull desperadoccy/async-fl
docker run -it async-fl config/FedAvg-config.json

类似地,支持传参config文件路径。 也可以自行build

cd docker
docker build -t async-fl .
docker run -it async-fl config/FedAvg-config.json 

特性

  • 异步联邦学习
  • 支持替换模型和数据集
  • 支持替换调度算法
  • 支持替换聚合算法
  • 支持替换loss函数
  • 支持替换客户端
  • 同步联邦学习
  • 半异步联邦学习
  • 提供test loss信息
  • 自定义标签异构
  • 自定义数据异构
  • 支持dirichlet distribution
  • wandb可视化
  • leaf相关数据集支持
  • 支持多GPU
  • docker部署

项目目录

Project Directory

.
├── config                                    常见算法配置
│   ├── FedAT-config.json
│   ├── FedAsync-config.json
│   ├── FedAvg-config.json
│   └── FedProx-config.json
├── config.json                               配置文件
├── config_semi.json                          配置文件
├── config_semi_test.json                     配置文件
├── config_sync.json                          配置文件
├── config_sync_test.json                     配置文件
├── config_test.json                          配置文件
├── doc
│   ├── pic
│   │   ├── fedsemi.png
│   │   ├── framework.png
│   │   └── header.png
│   └── readme-zh.md
├── docker
│   └── Dockerfile
├── license
├── fedsemi.png
├── framework.png
├── readme.md
├── requirements.txt
└── src 
    ├── client                                客户端实现
    │   ├── AsyncClient.py                    异步客户端类
    │   ├── Client.py                         客户端基类
    │   ├── ProxClient.py
    │   ├── SemiClient.py
    │   ├── SyncClient.py                     同步客户端类
    │   └── __init__.py
    ├── data                                  数据集下载位置
    ├── dataset                               数据集类
    │   ├── CIFAR10.py
    │   ├── MNIST.py
    │   ├── FashionMNIST.py
    │   └── __init__.py
    ├── exception                             异常类
    │   ├── ClientSumError.py
    │   └── __init__.py
    ├── fedasync                              异步联邦学习
    │   ├── AsyncClientManager.py             客户端管理类
    │   ├── AsyncServer.py                    异步服务器类
    │   ├── SchedulerThread.py                调度进程
    │   ├── UpdaterThread.py                  聚合进程
    │   └── __init__.py
    ├── fedsemi                               半异步联邦学习
    │   ├── QueueManager.py                   队列管理类
    │   ├── SchedulerThread.py                调度进程
    │   ├── SemiAsyncClientManager.py         客户端管理类
    │   ├── SemiAsyncServer.py                服务器类
    │   ├── UpdaterThread.py                  聚合进程
    │   ├── __init__.py
    │   ├── checker                           半异步检查器
    │   │   └── SemiAvgChecker.py
    │   ├── grouping                          分组(层)器
    │   │   ├── Grouping.py
    │   │   ├── NormalGrouping.py
    │   │   └── SimpleGrouping.py
    │   └── receiver                          半异步接收器
    │       └── SemiAvgReceiver.py
    ├── fedsync                               同步联邦学习
    │   ├── QueueManager.py                   消息队列管理类
    │   ├── SchedulerThread.py                调度进程
    │   ├── SyncClientManager.py              客户端管理类
    │   ├── SyncServer.py                     同步服务器类
    │   ├── UpdaterThread.py                  聚合进程
    │   ├── __init__.py
    │   ├── checker                           同步检查器
    │   │   └── AvgChecker.py
    │   └── receiver                          同步接收器
    │       └── AvgReceiver.py
    ├── fl                                    fl主函数
    │   ├── __init__.py
    │   ├── main.py
    │   └── wandb                             wandb运行文件夹
    ├── loss                                  loss函数实现
    │   └── __init__.py
    ├── model                                 模型类
    │   ├── CNN.py
    │   ├── ConvNet.py
    │   └── __init__.py
    ├── results                               实验结果
    ├── schedule                              调度算法类
    │   ├── FullSchedule.py
    │   ├── RandomSchedule.py
    │   ├── RoundRobin.py
    │   └── __init__.py
    ├── test                                  测试用
    ├── update                                聚合算法类
    │   ├── AsyncAvg.py
    │   ├── FedAT.py
    │   ├── FedAsync.py
    │   ├── FedAvg.py
    │   ├── MyFed.py
    │   └── __init__.py
    └── utils                                 工具集
        ├── ConfigManager.py
        ├── IID.py
        ├── JsonTool.py
        ├── ModuleFindTool.py
        ├── ModelTraining.py
        ├── Plot.py
        ├── ProcessTool.py
        ├── Queue.py
        ├── Random.py
        ├── Time.py
        ├── Tools.py
        └── __init__.py

utils包下的Time文件是一个多线程时间获取类的实现;Queue文件是因为mac的多线程queue部分功能未实现,对queue相关功能的实现。

框架结构

error

error

类解释

接收器类

接收器是同步|半异步联邦学习为了检查该轮全局迭代接收的更新是否满足设置的条件,如所有指定的客户端均已上传更新,满足条件则会触发updater进程进行全局聚合。

检查器类

同步|半异步联邦学习中客户端完成训练后,会将权重上传给检查器类,检查起根据自身逻辑判断是否符合上传标准,选择接收或舍弃该更新。

配置文件

异步配置文件

example

同步配置文件

example

半异步配置文件

example

参数解释

参数解释

参数

类型

说明

wandb

enabled

bool

是否启用wandb

project

string

项目名称

name

string

本次运行的名称

global

use_file_system

bool

是否启用文件系统作为torch多线程共享策略

multi_gpu

bool

是否启用多GPU,详细解释

experiment

string

本次运行的名称

stale

解释

dataset

path

string

数据集所在路径

params

dict

所需变量

iid

解释

client_num

int

客户端数量

server

path

string

server所在路径

epochs

int

全局运行轮次

model

path

string

模型所在路径

params

dict

所需变量

scheduler

path

string

scheduler所在路径

schedule

path

string

schedule所在路径

params

dict

所需变量

other_params

*

其他变量

updater

path

string

updater所在路径

update

path

string

update所在路径

params

dict

所需变量

loss

解释

num_generator

解释

group

path

string

update所在路径

params

dict

所需变量

client_manager

path

string

client manager所在路径

group_manager

path

string

group manager所在路径

group_method

path

string

group method所在路径

params

dict

所需变量

queue_manager

path

string

queue manager所在路径

receiver

path

string

receiver所在路径

params

dict

所需变量

checker

path

string

checker所在路径

params

dict

所需变量

client

path

string

client所在路径

epochs

int

本地运行次数

batch_size

int

batch

model

path

string

模型所在路径

params

dict

所需变量

loss

解释

mu

float

近端项系数

optimizer

path

string

optimizer所在路径

params

dict

所需变量

other_params

*

其他变量

添加新的算法

需要让客户端/服务器调用自己的算法或实现类,(注意:所有的算法实现必须以类的形式),需要以下几个步骤:

  • 在对应的位置加入自己的实现(dataset、model、schedule、update、client、loss)
  • 在对应包的__init__.py文件下导入该类,例如from model import CNN
  • 在配置文件申明,model.path等对应的是新的算法所在路径。
  • checker, group, receiver, schedule, update模块需要在Caller类中补充调用方法
  • loss, numgenerator模块需要在factory类中补充调用方法

另外,算法里需要使用到的参数均可在配置项params中申明。

现在model, optimloss模块均支持引入torch等其他库的实现类,例:

"model": {
      "path": "torchvision.models.resnet18",
      "params": {
        "pretrained": true,
        "num_classes": 10 
      }
}

添加loss

loss函数现由LossFactory类生成创建,可以选择torch自带算法,也可以自行实现: loss支持三种设置,其一是配置文件中普遍使用的string方式:

"loss": "torch.nn.functional.cross_entropy"

程序会直接生成函数式loss。

其二是生成对象式loss:

"loss": {
    "path": "loss.myloss.MyLoss",
    "params": {}
}

其三是根据type生成loss:

"loss": {
        "type": "func",
        "path": "loss.myloss.MyLoss",
        "params": {}
    }

staleness设置

stale支持三种设置,其一是上述配置文件中提到的

"stale": {
      "step": 5,
      "shuffle": true,
      "list": [10, 10, 10, 5, 5, 5, 5]
    }

程序会根据提供的steplist生成一串随机整数,例如上述代码,程序会生成10个0,10个(0,5),10个[5,10)......,并会根据shuffle判断是否进行打乱。最后将随机数串赋给各客户端,客户端根据数值在每轮训练结束后,自动sleep对应秒。在存储json文件至实验结果时,该设置会自动转为其三。

其二是设置为false,程序会给各客户端延迟设置为0。

"stale": false

其三是随机数列表,程序直接会将列表指定延迟设置给客户端。

"stale": [1, 2, 3, 1, 4]

数据分布设置

iid

当iid设置为true时(其实false也是默认为iid),会以iid的方式将数据分配给各客户端。

"iid": true

dirichlet non-iid

iidcustomize设置为false或者不设置时,会以dirichlet分布的方式将数据分配给各客户端。 其中beta是dirichlet分布的参数。

"iid": {
    "customize": false,
    "beta": 0.5
}

或者

"iid": {
    "beta": 0.5
}

customize non-iid

customize non-iid设置分为两部分,一个是标签的non-iid设置,一个是数据量的non-iid设置。目前数据量仅提供随机生成,在未来的版本中将引入个性化设置。 在启用customize设置时,需要将customize设置为true并分别对labeldata进行设置

"iid": {
    "customize": true
}

label distribution

label的设置stale的设置类似,支持三种方式,其一为配置文件中提到的

"label": {
    "step": 1,
    "list": [10, 10, 30]
}

其上配置程序会生成10个拥有1个标签数据的客户端,10个拥有2个标签数据的客户端,30个拥有3个标签数据的客户端 step是标签数量的步长,当step为2时,程序会生成10个拥有1个标签数据的客户端,10个拥有3个标签数据的客户端,30个拥有5个标签数据的客户端

其二为随机数二维数组,程序将二维数组直接设置给客户端

"label": {
    "0": [1, 2, 3, 8],
    "1": [2, 4],
    "2": [4, 7],
    "3": [0, 2, 3, 6, 9],
    "4": [5]
}

其三为一维数组,该一维数组为每个客户端拥有的标签数,该数组长度应和客户端数量一致。

"label": {
  "list": [4, 5, 10, 1, 2, 3, 4]
}

上述配置即客户端0拥有4个标签数据,客户端1拥有5个标签数据...以此类推。

目前label_iid生成的随机化分为两种方法,一种纯随机化,这种情况可能会导致所有客户端均缺少一个标签,导致精度下降(虽然概率极低),另一种方式采用洗牌算法,保证每个标签均会选到,这也会导致无法生成标签分布不均匀的数据情况。洗牌算法的开关由shuffle控制,示例如下:

"label": {
  "shuffle": true,
  "list": [4, 5, 10, 1, 2, 3, 4]
}

data distribution

data的设置比较简单,目前有两种方式,其一为空

"data": {}

也就是不对数据量进行非独立同分布设置。

其二为配置文件中提到的

"data": {
    "max": 500,
    "min": 400
}

也就是说客户端的数据量范围在400-500,程序会自动平均分配到各标签

数据量分布还较初始,之后将会逐步完善

客户端替换

目前客户端替换需要继承AsyncClientSyncClient,新增的参数通过client配置项传入类中。

多GPU

本项目的多GPU特性并不是多GPU并行计算,各客户端训练依旧在单GPU上,但宏观上客户端运行在多个GPU上,也就是每个客户端的训练任务会平均分布到程序可见的GPU上,每个客户端绑定的GPU是在初始化时就指定好的,并不是每轮训练时指定,因此依旧会出现各GPU负载严重不均的可能情况。 该特性通过global下的multi_gpu控制开关。

代码尚存问题

目前框架里面有一个核心问题,客户端和服务器之间的通信使用的是multiprocessing的queue实现的,但是该队列在接收cuda张量后,当其他进程获取该张量,会导致内存溢出,程序异常退出。

这个bug是pytorch和queue导致的bug,暂时采取的解决方法是上传非cuda张量,聚合时再将其转为cuda张量,因此在添加聚合算法时,大致会需要出现如下代码:

updated_parameters = {}
for key, var in client_weights.items():
    updated_parameters[key] = var.clone()
    if torch.cuda.is_available():
        updated_parameters[key] = updated_parameters[key].cuda()

Contributors

desperadoccy
Desperadoccy
jzj007
Jzj007

联系我

QQ: 527707607

邮箱: [email protected]

欢迎对项目提出建议~