repo sync 240103 (secretflow#1114)
* repo-sync-2024-01-03T11:58:59+0800

* repo-sync-2024-01-03T12:14:31+0800

* revert yacl
ian-huu authored Jan 3, 2024
1 parent c1d8fd4 commit df4b2e7
Showing 96 changed files with 5,851 additions and 654 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -352,4 +352,4 @@ selected-keys.*
*Miniconda3.sh

# pytest-testmon db
-.testmondata
+.testmondata*
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`Fixed` for any bug fixes.
`Security` in case of vulnerabilities.

## [1.4.0.dev240103] - 2024-1-3
### Added
- Add grad replace attack and replay attack.
- Add autoattack benchmark examples.
- Expose job_name param in sf.init.

### Changed
- Bump rayfed version: optimizing Error Propagation and Capture.
- Component: woe_bins requires at least 5 bins to read.
- Component: add barrier_on_shutdown as sf cluster config.

## [1.4.0.dev231225] - 2023-12-25
### Added
- Add DataProxy binary writer.
78 changes: 78 additions & 0 deletions benchmark_examples/autoattack/README.md
@@ -0,0 +1,78 @@
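# Federated Attack and Defense Benchmark

The federated attack and defense framework provides an automatic tuning tool (secretflow tuner). Besides conventional AutoML capabilities,
it works with the federated learning callback mechanism to tune attacks and defenses automatically.

Once an attack algorithm is implemented, users can use the framework to search for the most effective attack parameters, model/data split methods, and so on,
and use the results to assess the security of a federated algorithm.

On top of the federated algorithms, we implemented several attack algorithms on a number of classic datasets and models, built benchmarks for them, and collected the tuning results.
The benchmarks currently supported are: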

|    | datasets  | models   | lia | fia | replay | replace | exploit | norm |
|---:|:----------|:---------|:----|:----|:-------|:--------|:--------|:-----|
|  0 | bank      | dnn      |     |     | -      | -       | -       |      |
|  1 | bank      | deepfm   |     |     | -      | -       | -       |      |
|  2 | bank      | resnet18 | -   | -   | -      | -       | -       | -    |
|  3 | bank      | vgg16    | -   | -   | -      | -       | -       | -    |
|  4 | bank      | resnet20 | -   | -   | -      | -       | -       | -    |
|  5 | movielens | dnn      |     |     | -      | -       | -       | -    |
|  6 | movielens | deepfm   |     |     | -      | -       | -       | -    |
|  7 | movielens | resnet18 | -   | -   | -      | -       | -       | -    |
|  8 | movielens | vgg16    | -   | -   | -      | -       | -       | -    |
|  9 | movielens | resnet20 | -   | -   | -      | -       | -       | -    |
| 10 | drive     | dnn      | -   |     | -      | -       | -       | -    |
| 11 | drive     | deepfm   | -   | -   | -      | -       | -       | -    |
| 12 | drive     | resnet18 | -   | -   | -      | -       | -       | -    |
| 13 | drive     | vgg16    | -   | -   | -      | -       | -       | -    |
| 14 | drive     | resnet20 | -   | -   | -      | -       | -       | -    |
| 15 | criteo    | dnn      | -   | -   | -      | -       | -       |      |
| 16 | criteo    | deepfm   | -   | -   | -      | -       | -       |      |
| 17 | criteo    | resnet18 | -   | -   | -      | -       | -       | -    |
| 18 | criteo    | vgg16    | -   | -   | -      | -       | -       | -    |
| 19 | criteo    | resnet20 | -   | -   | -      | -       | -       | -    |
| 20 | mnist     | dnn      | -   | -   | -      | -       | -       | -    |
| 21 | mnist     | deepfm   | -   | -   | -      | -       | -       | -    |
| 22 | mnist     | resnet18 | -   | -   | -      | -       | -       | -    |
| 23 | mnist     | vgg16    | -   | -   | -      | -       | -       | -    |
| 24 | mnist     | resnet20 | -   | -   | -      | -       | -       | -    |
| 25 | cifar10   | dnn      | -   | -   | -      | -       | -       | -    |
| 26 | cifar10   | deepfm   | -   | -   | -      | -       | -       | -    |
| 27 | cifar10   | resnet18 |     | -   | -      | -       | -       | -    |
| 28 | cifar10   | vgg16    | -   | -   | -      | -       | -       | -    |
| 29 | cifar10   | resnet20 |     | -   | -      | -       | -       | -    |

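## How to add a new implementation

The code lives in the `benchmark_examples/autoattack` directory.

The `applications` directory contains the concrete dataset + model implementations.
Its layout is `dataset category/dataset name/model name/implementation`, for example `image/cifar10/vgg16`.

The `attacks` directory contains the concrete attack implementations.
The code there is generic attack code. If an attack depends on a specific dataset, for example when it needs an auxiliary dataset and an auxiliary model, provide that code in the dataset + model code under `applications`.

Note: a new implementation must also be added to `autoattack/utils/distribution.py` so that it can be looked up.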
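
The directory layout described here maps naturally onto a dotted Python module path. The following is a minimal sketch of that mapping, not the repository's actual lookup: `DATASET_CATEGORY`, `APPLICATIONS_PACKAGE` and `application_module` are hypothetical names, and the real registration logic in `autoattack/utils/distribution.py` may work differently. Only the `image/cifar10` entry comes from the README.

```python
from typing import Dict

# Hypothetical registry: dataset name -> dataset category directory.
# Only the cifar10 entry is taken from the README example.
DATASET_CATEGORY: Dict[str, str] = {"cifar10": "image"}

# Package that holds the dataset + model implementations.
APPLICATIONS_PACKAGE = "benchmark_examples.autoattack.applications"


def application_module(dataset: str, model: str) -> str:
    """Build the dotted path <package>.<category>.<dataset>.<model>."""
    if dataset not in DATASET_CATEGORY:
        # An unregistered dataset cannot be found, which is why new
        # implementations have to be registered before they can run.
        raise KeyError(f"dataset {dataset!r} is not registered")
    category = DATASET_CATEGORY[dataset]
    return f"{APPLICATIONS_PACKAGE}.{category}.{dataset}.{model}"


print(application_module("cifar10", "vgg16"))
```

A path built this way could be passed to `importlib.import_module` once the corresponding package exists on the import path.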


## Running a single test

```shell
cd secretflow
# train
python benchmark_examples/autoattack/main.py bank dnn train
# attack
python benchmark_examples/autoattack/main.py bank dnn lia
# automatically tuned attack
python benchmark_examples/autoattack/main.py bank dnn auto_lia
```
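
The same three runs can be scripted. This is a minimal sketch that only builds and prints the command lines; `build_command` and `MAIN` are hypothetical helpers, and the script path uses `benchmark_examples`, the directory name that appears in this commit's file list.

```python
import sys
from typing import List

# Path of the entry script relative to the repository root.
MAIN = "benchmark_examples/autoattack/main.py"


def build_command(dataset: str, model: str, mode: str) -> List[str]:
    """Return the argv for one run; mode is e.g. 'train', 'lia' or 'auto_lia'."""
    return [sys.executable, MAIN, dataset, model, mode]


# The three runs from the shell snippet, in the same order.
commands = [build_command("bank", "dnn", m) for m in ("train", "lia", "auto_lia")]
for cmd in commands:
    print(" ".join(cmd))
    # To execute for real from the repository root (assumes secretflow is
    # installed), import subprocess and call: subprocess.run(cmd, check=True)
```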

## Running the benchmark

```shell
cd secretflow
# train
python benchmark_examples/autoattack/benchmark.py train
# automatically tuned attacks
python benchmark_examples/autoattack/benchmark.py auto
```
13 changes: 13 additions & 0 deletions benchmark_examples/autoattack/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions benchmark_examples/autoattack/applications/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
94 changes: 94 additions & 0 deletions benchmark_examples/autoattack/applications/base.py
@@ -0,0 +1,94 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union

from secretflow.ml.nn.callbacks.callback import Callback


class TrainBase(ABC):
    def __init__(
        self,
        config: Dict,
        alice,
        bob,
        device_y,
        num_classes,
        epoch=2,
        train_batch_size=128,
    ):
        self.config = config
        self.alice = alice
        self.epoch = config.get('epoch', epoch)
        self.train_batch_size = config.get('train_batch_size', train_batch_size)
        self.bob = bob
        self.device_y = device_y
        self.num_classes = num_classes
        self.config = config
        (
            self.train_data,
            self.train_label,
            self.test_data,
            self.test_label,
        ) = self._prepare_data()
        self.alice_base_model = self._create_base_model_alice()
        self.bob_base_model = self._create_base_model_bob()
        self.fuse_model = self._create_fuse_model()

    @abstractmethod
    def train(self, callbacks: Optional[Union[List[Callback], Callback]] = None):
        pass

    def predict(self):
        raise NotImplementedError("Predict not implemented.")

    @abstractmethod
    def _prepare_data(self):
        pass

    @abstractmethod
    def _create_base_model_alice(self):
        pass

    @abstractmethod
    def _create_base_model_bob(self):
        pass

    @abstractmethod
    def _create_fuse_model(self):
        pass

    def support_attacks(self):
        """
        Which attacks this application supports.
        Returns:
            List of attack names, default is empty.
        """
        return []

    def lia_auxiliary_data_builder(self, batch_size=16, file_path=None):
        raise NotImplementedError(
            f"need implement lia_auxiliary_data_builder on {type(self).__name__} "
        )

    def lia_auxiliary_model(self, ema=False):
        raise NotImplementedError(
            f"need implement lia_auxiliary_model on {type(self).__name__} "
        )

    def fia_auxiliary_data_builder(self):
        raise NotImplementedError(
            f"need implement fia_auxiliary_data_builder on {type(self).__name__} "
        )
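
`TrainBase` follows a template-method pattern: the constructor reads overrides from `config`, then calls abstract hooks that each application fills in, while attack-specific helpers raise `NotImplementedError` until a subclass provides them. The sketch below reproduces that pattern in a self-contained form so it runs without secretflow. `MiniTrainBase` and `ToyApp` are hypothetical stand-ins, not classes from the repository, and the placeholder data is invented for illustration.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class MiniTrainBase(ABC):
    """Stand-in for TrainBase: config overrides defaults, hooks build the parts."""

    def __init__(self, config: Dict, epoch: int = 2, train_batch_size: int = 128):
        # Values present in config take precedence over the keyword defaults.
        self.epoch = config.get("epoch", epoch)
        self.train_batch_size = config.get("train_batch_size", train_batch_size)
        # The constructor calls the hook; subclasses only implement the hook.
        self.train_data = self._prepare_data()

    @abstractmethod
    def _prepare_data(self) -> List[int]:
        ...

    def support_attacks(self) -> List[str]:
        return []  # default: no attacks supported

    def lia_auxiliary_model(self):
        raise NotImplementedError(
            f"need implement lia_auxiliary_model on {type(self).__name__}"
        )


class ToyApp(MiniTrainBase):
    def _prepare_data(self) -> List[int]:
        return list(range(8))  # placeholder data, not a real dataset

    def support_attacks(self) -> List[str]:
        return ["lia"]  # hypothetical: this toy app advertises only 'lia'


app = ToyApp({"epoch": 5})
print(app.epoch, app.train_batch_size, app.support_attacks())  # 5 128 ['lia']
```

Note that `ToyApp` advertises `lia` without overriding `lia_auxiliary_model`, so calling that helper still raises `NotImplementedError`. An application that lists an attack in `support_attacks` would be expected to supply the matching auxiliary builders as well.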
13 changes: 13 additions & 0 deletions benchmark_examples/autoattack/applications/image/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,13 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,143 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from abc import ABC
from typing import List, Optional, Union

import torch
from torchvision import datasets, transforms

from benchmark_examples.autoattack.applications.base import TrainBase
from secretflow.ml.nn import SLModel
from secretflow.ml.nn.callbacks.callback import Callback

from .data_utils import CIFAR10Labeled, CIFAR10Unlabeled, label_index_split


class Cifar10TrainBase(TrainBase, ABC):
    def __init__(self, config, alice, bob, epoch=1, train_batch_size=128):
        super().__init__(
            config, alice, bob, bob, 10, epoch=epoch, train_batch_size=train_batch_size
        )

    def train(self, callbacks: Optional[Union[List[Callback], Callback]] = None):
        sl_model = SLModel(
            base_model_dict={
                self.alice: self.alice_base_model,
                self.bob: self.bob_base_model,
            },
            device_y=self.device_y,
            model_fuse=self.fuse_model,
            simulation=True,
            random_seed=1234,
            backend='torch',
            strategy='split_nn',
        )
        history = sl_model.fit(
            x=self.train_data,
            y=self.train_label,
            validation_data=(self.test_data, self.test_label),
            epochs=self.epoch,
            batch_size=self.train_batch_size,
            shuffle=False,
            random_seed=1234,
            dataset_builder=None,
            callbacks=callbacks,
        )
        logging.warning(history)

    def _prepare_data(self):
        from secretflow.utils.simulation import datasets

        (train_data, train_label), (test_data, test_label) = datasets.load_cifar10(
            [self.alice, self.bob],
        )

        return train_data, train_label, test_data, test_label

    def lia_auxiliary_data_builder(
        self, batch_size=16, file_path="~/.secretflow/datasets/cifar10"
    ):
        def prepare_data():
            n_labeled = 40
            num_classes = 10

            def get_transforms():
                transform_ = transforms.Compose(
                    [
                        transforms.ToTensor(),
                    ]
                )
                return transform_

            transforms_ = get_transforms()

            base_dataset = datasets.CIFAR10(file_path, train=True)

            train_labeled_idxs, train_unlabeled_idxs = label_index_split(
                base_dataset.targets, int(n_labeled / num_classes), num_classes
            )
            train_labeled_dataset = CIFAR10Labeled(
                file_path, train_labeled_idxs, train=True, transform=transforms_
            )
            train_unlabeled_dataset = CIFAR10Unlabeled(
                file_path, train_unlabeled_idxs, train=True, transform=transforms_
            )
            train_complete_dataset = CIFAR10Labeled(
                file_path, None, train=True, transform=transforms_
            )
            test_dataset = CIFAR10Labeled(
                file_path, train=False, transform=transforms_, download=True
            )
            print(
                "#Labeled:",
                len(train_labeled_idxs),
                "#Unlabeled:",
                len(train_unlabeled_idxs),
            )

            labeled_trainloader = torch.utils.data.DataLoader(
                train_labeled_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=0,
                drop_last=True,
            )
            unlabeled_trainloader = torch.utils.data.DataLoader(
                train_unlabeled_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=0,
                drop_last=True,
            )
            dataset_bs = batch_size * 10
            test_loader = torch.utils.data.DataLoader(
                test_dataset, batch_size=dataset_bs, shuffle=False, num_workers=0
            )
            train_complete_trainloader = torch.utils.data.DataLoader(
                train_complete_dataset,
                batch_size=dataset_bs,
                shuffle=False,
                num_workers=0,
                drop_last=True,
            )
            return (
                labeled_trainloader,
                unlabeled_trainloader,
                test_loader,
                train_complete_trainloader,
            )

        return prepare_data
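
The builder passes `int(n_labeled / num_classes)`, that is 4, as the per-class quota to `label_index_split`, which lives in `data_utils` and is not shown in this diff. The following is a minimal sketch of what such a split could look like, assuming it picks a fixed number of indices per class and treats the rest as unlabeled; the real function may differ. `label_index_split_sketch` and `num_batches` are hypothetical helpers, and the toy label list stands in for `base_dataset.targets`.

```python
import random
from typing import List, Sequence, Tuple


def label_index_split_sketch(
    labels: Sequence[int], n_labeled_per_class: int, num_classes: int, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Pick n_labeled_per_class indices per class; the rest are unlabeled."""
    rng = random.Random(seed)
    labeled: List[int] = []
    unlabeled: List[int] = []
    for cls in range(num_classes):
        idxs = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idxs)
        labeled.extend(idxs[:n_labeled_per_class])
        unlabeled.extend(idxs[n_labeled_per_class:])
    return labeled, unlabeled


def num_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a loader yields over n_samples items."""
    full, rest = divmod(n_samples, batch_size)
    return full if drop_last or rest == 0 else full + 1


# Toy labels: 10 classes with 50 samples each.
labels = [i % 10 for i in range(500)]
n_labeled, num_classes = 40, 10
labeled, unlabeled = label_index_split_sketch(
    labels, n_labeled // num_classes, num_classes
)
print(len(labeled), len(unlabeled))  # 40 460
# With batch_size=16 and drop_last=True, 40 labeled samples give 2 full batches.
print(num_batches(len(labeled), 16, drop_last=True))  # 2
```

With the defaults in the builder (`batch_size=16`, `drop_last=True`), the labeled loader therefore yields two batches per pass and leaves 8 of the 40 labeled samples unused in each pass; because `shuffle=True`, which 8 are skipped varies between passes.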