forked from secretflow/secretflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* repo-sync-2024-01-03T11:58:59+0800 * repo-sync-2024-01-03T12:14:31+0800 * revert yacl
- Loading branch information
Showing
96 changed files
with
5,851 additions
and
654 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -352,4 +352,4 @@ selected-keys.* | |
*Miniconda3.sh | ||
|
||
# pytest-testmon db | ||
.testmondata | ||
.testmondata* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# 联邦攻防Benchmark | ||
|
||
联邦攻防框架提供了自动调优工具(secretflow tuner),除了可以完成传统的automl能力, | ||
还配合联邦学习的callback机制,实现攻防的自动调优。 | ||
|
||
用户实现攻击算法后,可以便捷地通过攻防框架,调整找到最合适的攻击参数、模型/数据拆分方法等, | ||
可以借此来判断联邦算法的安全性。 | ||
|
||
在联邦算法的基础上,我们在几个经典的数据集+模型上,分别实现了几个攻击算法,并实现了benchmark,获得调优的结果。 | ||
目前支持的benchmark包括: | ||
|
||
| | datasets | models | lia | fia | replay | replace | exploit | norm | | ||
|---:|:----------|:---------|:----|:----|:-------|:--------|:--------|:-----| | ||
| 0 | bank | dnn | ✅ | | - | - | - | ✅ | | ||
| 1 | bank | deepfm | | | - | - | - | ✅ | | ||
| 2 | bank | resnet18 | - | - | - | - | - | - | | ||
| 3 | bank | vgg16 | - | - | - | - | - | - | | ||
| 4 | bank | resnet20 | - | - | - | - | - | - | | ||
| 5 | movielens | dnn | | | - | - | - | - | | ||
| 6 | movielens | deepfm | | | - | - | - | - | | ||
| 7 | movielens | resnet18 | - | - | - | - | - | - | | ||
| 8 | movielens | vgg16 | - | - | - | - | - | - | | ||
| 9 | movielens | resnet20 | - | - | - | - | - | - | | ||
| 10 | drive | dnn | - | ✅ | - | - | - | - | | ||
| 11 | drive | deepfm | - | - | - | - | - | - | | ||
| 12 | drive | resnet18 | - | - | - | - | - | - | | ||
| 13 | drive | vgg16 | - | - | - | - | - | - | | ||
| 14 | drive | resnet20 | - | - | - | - | - | - | | ||
| 15 | criteo | dnn | - | - | - | - | - | ✅ | | ||
| 16 | criteo | deepfm | - | - | - | - | - | ✅ | | ||
| 17 | criteo | resnet18 | - | - | - | - | - | - | | ||
| 18 | criteo | vgg16 | - | - | - | - | - | - | | ||
| 19 | criteo | resnet20 | - | - | - | - | - | - | | ||
| 20 | mnist | dnn | - | - | - | - | - | - | | ||
| 21 | mnist | deepfm | - | - | - | - | - | - | | ||
| 22 | mnist | resnet18 | - | - | - | - | - | - | | ||
| 23 | mnist | vgg16 | - | - | - | - | - | - | | ||
| 24 | mnist | resnet20 | - | - | - | - | - | - | | ||
| 25 | cifar10 | dnn | - | - | - | - | - | - | | ||
| 26 | cifar10 | deepfm | - | - | - | - | - | - | | ||
| 27 | cifar10 | resnet18 | ✅ | - | - | - | - | - | | ||
| 28 | cifar10 | vgg16 | - | - | - | - | - | - | | ||
| 29 | cifar10 | resnet20 | ✅ | - | - | - | - | - | | ||
|
||
## 如何添加新的实现 | ||
|
||
代码在`benchmark_example/autoattack`目录下。 | ||
|
||
`applications`目录下为具体的数据集+模型实现。 | ||
其中目录结构为`数据集分类/数据集名称/模型名称/具体实现`,例如`image/cifar10/vgg16`。 | ||
|
||
`attacks`目录下为具体的攻击实现。 | ||
其中编写的是通用的攻击代码,如果攻击依赖于具体的数据集,例如需要辅助数据集和辅助模型,则在application下的数据+模型代码中,提供这些代码。 | ||
|
||
ps:需要在`autoattack/utils/distribution.py`中添加新的实现,以便能够检索到。 | ||
|
||
|
||
## 运行单条测试 | ||
|
||
```shell | ||
cd secretflow | ||
# 训练 | ||
python benchmark_example/autoattack/main.py bank dnn train | ||
# 攻击 | ||
python benchmark_example/autoattack/main.py bank dnn lia | ||
# auto攻击 | ||
python benchmark_example/autoattack/main.py bank dnn auto_lia | ||
``` | ||
|
||
## 运行benchmark | ||
|
||
```shell | ||
cd secretflow | ||
# 训练 | ||
python benchmark_example/autoattack/benchmark.py train | ||
# auto | ||
python benchmark_example/autoattack/benchmark.py auto | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Dict, List, Optional, Union | ||
|
||
from secretflow.ml.nn.callbacks.callback import Callback | ||
|
||
|
||
class TrainBase(ABC): | ||
def __init__( | ||
self, | ||
config: Dict, | ||
alice, | ||
bob, | ||
device_y, | ||
num_classes, | ||
epoch=2, | ||
train_batch_size=128, | ||
): | ||
self.config = config | ||
self.alice = alice | ||
self.epoch = config.get('epoch', epoch) | ||
self.train_batch_size = config.get('train_batch_size', train_batch_size) | ||
self.bob = bob | ||
self.device_y = device_y | ||
self.num_classes = num_classes | ||
self.config = config | ||
( | ||
self.train_data, | ||
self.train_label, | ||
self.test_data, | ||
self.test_label, | ||
) = self._prepare_data() | ||
self.alice_base_model = self._create_base_model_alice() | ||
self.bob_base_model = self._create_base_model_bob() | ||
self.fuse_model = self._create_fuse_model() | ||
|
||
@abstractmethod | ||
def train(self, callbacks: Optional[Union[List[Callback], Callback]] = None): | ||
pass | ||
|
||
def predict(self): | ||
raise NotImplementedError("Predict not implemented.") | ||
|
||
@abstractmethod | ||
def _prepare_data(self): | ||
pass | ||
|
||
@abstractmethod | ||
def _create_base_model_alice(self): | ||
pass | ||
|
||
@abstractmethod | ||
def _create_base_model_bob(self): | ||
pass | ||
|
||
@abstractmethod | ||
def _create_fuse_model(self): | ||
pass | ||
|
||
def support_attacks(self): | ||
""" | ||
Which attacks this application supports. | ||
Returns: | ||
List of attack names, default is empty. | ||
""" | ||
return [] | ||
|
||
def lia_auxiliary_data_builder(self, batch_size=16, file_path=None): | ||
raise NotImplementedError( | ||
f"need implement lia_auxiliary_data_builder on {type(self).__name__} " | ||
) | ||
|
||
def lia_auxiliary_model(self, ema=False): | ||
raise NotImplementedError( | ||
f"need implement lia_auxiliary_model on {type(self).__name__} " | ||
) | ||
|
||
def fia_auxiliary_data_builder(self): | ||
raise NotImplementedError( | ||
f"need implement fia_auxiliary_data_builder on {type(self).__name__} " | ||
) |
13 changes: 13 additions & 0 deletions
13
benchmark_examples/autoattack/applications/image/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
13 changes: 13 additions & 0 deletions
13
benchmark_examples/autoattack/applications/image/cifar10/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
143 changes: 143 additions & 0 deletions
143
benchmark_examples/autoattack/applications/image/cifar10/cifar10_base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
from abc import ABC | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
from torchvision import datasets, transforms | ||
|
||
from benchmark_examples.autoattack.applications.base import TrainBase | ||
from secretflow.ml.nn import SLModel | ||
from secretflow.ml.nn.callbacks.callback import Callback | ||
|
||
from .data_utils import CIFAR10Labeled, CIFAR10Unlabeled, label_index_split | ||
|
||
|
||
class Cifar10TrainBase(TrainBase, ABC): | ||
def __init__(self, config, alice, bob, epoch=1, train_batch_size=128): | ||
super().__init__( | ||
config, alice, bob, bob, 10, epoch=epoch, train_batch_size=train_batch_size | ||
) | ||
|
||
def train(self, callbacks: Optional[Union[List[Callback], Callback]] = None): | ||
sl_model = SLModel( | ||
base_model_dict={ | ||
self.alice: self.alice_base_model, | ||
self.bob: self.bob_base_model, | ||
}, | ||
device_y=self.device_y, | ||
model_fuse=self.fuse_model, | ||
simulation=True, | ||
random_seed=1234, | ||
backend='torch', | ||
strategy='split_nn', | ||
) | ||
history = sl_model.fit( | ||
x=self.train_data, | ||
y=self.train_label, | ||
validation_data=(self.test_data, self.test_label), | ||
epochs=self.epoch, | ||
batch_size=self.train_batch_size, | ||
shuffle=False, | ||
random_seed=1234, | ||
dataset_builder=None, | ||
callbacks=callbacks, | ||
) | ||
logging.warning(history) | ||
|
||
def _prepare_data(self): | ||
from secretflow.utils.simulation import datasets | ||
|
||
(train_data, train_label), (test_data, test_label) = datasets.load_cifar10( | ||
[self.alice, self.bob], | ||
) | ||
|
||
return train_data, train_label, test_data, test_label | ||
|
||
def lia_auxiliary_data_builder( | ||
self, batch_size=16, file_path="~/.secretflow/datasets/cifar10" | ||
): | ||
def prepare_data(): | ||
n_labeled = 40 | ||
num_classes = 10 | ||
|
||
def get_transforms(): | ||
transform_ = transforms.Compose( | ||
[ | ||
transforms.ToTensor(), | ||
] | ||
) | ||
return transform_ | ||
|
||
transforms_ = get_transforms() | ||
|
||
base_dataset = datasets.CIFAR10(file_path, train=True) | ||
|
||
train_labeled_idxs, train_unlabeled_idxs = label_index_split( | ||
base_dataset.targets, int(n_labeled / num_classes), num_classes | ||
) | ||
train_labeled_dataset = CIFAR10Labeled( | ||
file_path, train_labeled_idxs, train=True, transform=transforms_ | ||
) | ||
train_unlabeled_dataset = CIFAR10Unlabeled( | ||
file_path, train_unlabeled_idxs, train=True, transform=transforms_ | ||
) | ||
train_complete_dataset = CIFAR10Labeled( | ||
file_path, None, train=True, transform=transforms_ | ||
) | ||
test_dataset = CIFAR10Labeled( | ||
file_path, train=False, transform=transforms_, download=True | ||
) | ||
print( | ||
"#Labeled:", | ||
len(train_labeled_idxs), | ||
"#Unlabeled:", | ||
len(train_unlabeled_idxs), | ||
) | ||
|
||
labeled_trainloader = torch.utils.data.DataLoader( | ||
train_labeled_dataset, | ||
batch_size=batch_size, | ||
shuffle=True, | ||
num_workers=0, | ||
drop_last=True, | ||
) | ||
unlabeled_trainloader = torch.utils.data.DataLoader( | ||
train_unlabeled_dataset, | ||
batch_size=batch_size, | ||
shuffle=False, | ||
num_workers=0, | ||
drop_last=True, | ||
) | ||
dataset_bs = batch_size * 10 | ||
test_loader = torch.utils.data.DataLoader( | ||
test_dataset, batch_size=dataset_bs, shuffle=False, num_workers=0 | ||
) | ||
train_complete_trainloader = torch.utils.data.DataLoader( | ||
train_complete_dataset, | ||
batch_size=dataset_bs, | ||
shuffle=False, | ||
num_workers=0, | ||
drop_last=True, | ||
) | ||
return ( | ||
labeled_trainloader, | ||
unlabeled_trainloader, | ||
test_loader, | ||
train_complete_trainloader, | ||
) | ||
|
||
return prepare_data |
Oops, something went wrong.