diff --git a/.bazeliskrc b/.bazeliskrc new file mode 100644 index 000000000..c23a95566 --- /dev/null +++ b/.bazeliskrc @@ -0,0 +1,2 @@ +USE_BAZEL_VERSION=5.1.1 +BAZELISK_BASE_URL=https://github.com/bazelbuild/bazel/releases/download diff --git a/CHANGELOG.md b/CHANGELOG.md index a35f0deab..957b46a51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,23 +12,37 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `Fixed` for any bug fixes. `Security` in case of vulnerabilities. -## [1.3.0.dev231211] - 2023-12-11 +## [1.3.0.dev231218] - 2023-12-18 ### Added -- Add IO component including read, write and identity. -- Change groupby component to by-query style. - -## [1.3.0.dev231205] - 2023-12-05 +- Make barrier_on_shutdown optional. +- Support SGB label holder without features. +- Support SL Model training on file data with mutiple labels. +- Add SL ResNet and VGG application. +- Secretflow ic: Add package interconnection protobuf files. +- Component: Add feature calculate component to generate new features by performing calculations on original features. +- Component: Support SGB prediction on big dataset. ### Changed -- Add feature selection in all model predict comps. +- SGB optimize memory usage in prediction. +- Component: Bump groupby statistics version. +- Component: Improve translation. ### Fixed -- Fix pvalue & more readable assert msg. +- Component: Fix woe io and fillna. + + +## [1.3.0.dev231211] - 2023-12-11 +### Added +- Add IO component including read, write and identity. +- Change groupby component to by-query style. + ## [1.3.0.dev231128] - 2023-11-28 ### Added - Add secretflow tuner for automl and autoattack. +- Add IO component including read, write and identity. +- Change groupby component to by-query style. ## [1.3.0.dev231120] - 2023-11-20 diff --git a/REPO_LAYOUT.md b/REPO_LAYOUT.md index 3f3a53719..92e0e0cdd 100644 --- a/REPO_LAYOUT.md +++ b/REPO_LAYOUT.md @@ -1,13 +1,26 @@ # Repository layout -secretflow -- data: horizontal, vertical and mixed DataFrame and Ndarray (like pandas and numpy) -- device: various devices and their kernels, such as PYU, SPU, HEU, etc -- model: federated learning and split learning algorithms -- preprocessing: common utility functions and transformer classes (like scikit-learn) -- security: privacy related algorithms, such as secure aggregation, differential privacy -- util: miscellaneous utility functions +This is a high level overview of how the repository is laid out. Some major folders are listed below: -tests: unit test cases - -docs: documents written in reStructuredText, Markdown, Jupyter-notebook +* [benchmark_examples/](benchmark_examples/): scripts for secretflow component benchmark. +* [docker/](docker/): scripts to build secretflow release and dev docker images. +* [docs/](docs/): documents written in reStructuredText, Markdown, Jupyter-notebook. +* [examples/](examples/): examples of secretflow. +* [secretflow/](secretflow/): the core library. + * [component/](secretflow/component/): secretflow components. + * [compute/](secretflow/compute/): wrapper for pyarrow compute functions. + * [data/](secretflow/data/): horizontal, vertical and mixed DataFrame and Ndarray (like pandas and numpy). + * [device/](secretflow/device/): various devices and their kernels, such as PYU, SPU, HEU, etc. + * [distributed/](secretflow/distributed/): logics related to Ray and RayFed. + * [ic/](secretflow/ic/): interconnection. + * [kuscia/](secretflow/kuscia/): adapter to kuscia. + * [ml/](secretflow/ml/): federated learning and split learning algorithms. + * [preprocessing/](secretflow/preprocessing/): preprocessing functions. + * [protos/](secretflow/protos/): Protocol Buffers messages. + * [security/](secretflow/security/): privacy related algorithms, such as secure aggregation, differential privacy. + * [spec/](secretflow/spec/): generated code of spec Protocol Buffers messages. + * [stats/](secretflow/stats/): statistics functions. + * [tune/](secretflow/tune/): functions related to tuners. + * [utils/](secretflow/utils/): miscellaneous utility functions. +* [secretflow_lib/](secretflow_lib/): some core functions written in C++ and their Python bindings. +* [tests/](tests/): unit tests with pytest. diff --git a/WORKSPACE b/WORKSPACE index b0e508d0e..595dcaa58 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -4,7 +4,7 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") git_repository( name = "yacl", - commit = "6ba8bd5f02035176ec4daaca1c1269195a1b1b4e", + commit = "2b7d8882c78f07bd9e78217b7f9ca13135781e65", remote = "https://github.com/secretflow/yacl.git", ) diff --git a/docker/comp_list.json b/docker/comp_list.json index 41e406744..02871f4c9 100644 --- a/docker/comp_list.json +++ b/docker/comp_list.json @@ -252,9 +252,8 @@ "inputs": [ { "name": "input_data", - "desc": "Input dist data", + "desc": "Input data", "types": [ - "sf.model.ss_glm", "sf.model.ss_glm", "sf.model.sgb", "sf.model.ss_xgb", @@ -268,9 +267,8 @@ "outputs": [ { "name": "output_data", - "desc": "Output dist data", + "desc": "Output data", "types": [ - "sf.model.ss_glm", "sf.model.ss_glm", "sf.model.sgb", "sf.model.ss_xgb", @@ -656,6 +654,17 @@ "b": true } } + }, + { + "name": "batch_size", + "desc": "Prediction batch size", + "type": "AT_INT", + "atomic": { + "isOptional": true, + "defaultValue": { + "i64": "100000" + } + } } ], "inputs": [ @@ -2147,6 +2156,52 @@ } ] }, + { + "domain": "preprocessing", + "name": "feature_calculate", + "desc": "Generate a new feature by performing calculations on an origin feature", + "version": "0.0.1", + "attrs": [ + { + "name": "rules", + "desc": "input CalculateOpRules rules", + "type": "AT_CUSTOM_PROTOBUF", + "customProtobufCls": "calculate_rules_pb2.CalculateOpRules" + } + ], + "inputs": [ + { + "name": "in_ds", + "desc": "Input vertical table", + "types": [ + "sf.table.vertical_table" + ], + "attrs": [ + { + "name": "features", + "desc": "Feature(s) to operate on", + "colMinCntInclusive": "1" + } + ] + } + ], + "outputs": [ + { + "name": "out_ds", + "desc": "output_dataset", + "types": [ + "sf.table.vertical_table" + ] + }, + { + "name": "out_rules", + "desc": "feature calculate rule", + "types": [ + "sf.rule.preprocessing" + ] + } + ] + }, { "domain": "preprocessing", "name": "feature_filter", @@ -2190,7 +2245,7 @@ "atomic": { "isOptional": true, "defaultValue": { - "s": "mean" + "s": "constant" }, "allowedValues": { "ss": [ @@ -2209,7 +2264,7 @@ "atomic": { "isOptional": true, "defaultValue": { - "s": "general_na" + "s": "custom_missing_value" } } }, @@ -2608,7 +2663,7 @@ "domain": "stats", "name": "groupby_statistics", "desc": "Get a groupby of statistics, like pandas groupby statistics.\nCurrently only support VDataframe.", - "version": "0.0.2", + "version": "0.0.3", "attrs": [ { "name": "aggregation_config", diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile index dcd0e5a71..be3f68e30 100644 --- a/docker/dev/Dockerfile +++ b/docker/dev/Dockerfile @@ -19,7 +19,7 @@ COPY --from=builder /bin/nsjail /usr/local/bin/ COPY --from=python /root/miniconda3/envs/secretflow/bin/ /usr/local/bin/ COPY --from=python /root/miniconda3/envs/secretflow/lib/ /usr/local/lib/ -RUN yum install -y protobuf libnl3 && yum clean all +RUN yum install -y protobuf libnl3 libgomp && yum clean all RUN grep -rl '#!/root/miniconda3/envs/secretflow/bin' /usr/local/bin/ | xargs sed -i -e 's/#!\/root\/miniconda3\/envs\/secretflow/#!\/usr\/local/g' diff --git a/docker/release/anolis-lite.Dockerfile b/docker/release/anolis-lite.Dockerfile index 42e4197e1..6fb15140c 100644 --- a/docker/release/anolis-lite.Dockerfile +++ b/docker/release/anolis-lite.Dockerfile @@ -19,7 +19,7 @@ COPY --from=builder /bin/nsjail /usr/local/bin/ COPY --from=python /root/miniconda3/envs/secretflow/bin/ /usr/local/bin/ COPY --from=python /root/miniconda3/envs/secretflow/lib/ /usr/local/lib/ -RUN yum install -y protobuf libnl3 && yum clean all +RUN yum install -y protobuf libnl3 libgomp && yum clean all RUN grep -rl '#!/root/miniconda3/envs/secretflow/bin' /usr/local/bin/ | xargs sed -i -e 's/#!\/root\/miniconda3\/envs\/secretflow/#!\/usr\/local/g' diff --git a/docker/release/anolis.Dockerfile b/docker/release/anolis.Dockerfile index ba02a6e83..972dbf647 100644 --- a/docker/release/anolis.Dockerfile +++ b/docker/release/anolis.Dockerfile @@ -19,7 +19,7 @@ COPY --from=builder /bin/nsjail /usr/local/bin/ COPY --from=python /root/miniconda3/envs/secretflow/bin/ /usr/local/bin/ COPY --from=python /root/miniconda3/envs/secretflow/lib/ /usr/local/lib/ -RUN yum install -y protobuf libnl3 && yum clean all +RUN yum install -y protobuf libnl3 libgomp && yum clean all RUN grep -rl '#!/root/miniconda3/envs/secretflow/bin' /usr/local/bin/ | xargs sed -i -e 's/#!\/root\/miniconda3\/envs\/secretflow/#!\/usr\/local/g' diff --git a/docker/translation.json b/docker/translation.json index 27da73b96..9313bb6a7 100644 --- a/docker/translation.json +++ b/docker/translation.json @@ -65,9 +65,9 @@ "map any input to output": "将任何输入映射到输出", "0.0.1": "0.0.1", "input_data": "input_data", - "Input dist data": "输入dist数据", + "Input data": "输入数据", "output_data": "output_data", - "Output dist data": "输出dist 数据" + "Output data": "输出数据" }, "io/read_data:0.0.1": { "io": "io", @@ -171,6 +171,8 @@ "Whether to save ids columns into output prediction table. If true, input feature_dataset must contain id columns, and receiver party must be id owner.": "是否将 id 列保存到输出预测表中;如果为 true,则输入feature_dataset必须包含 id 列,并且接收方必须是 id 所有者", "save_label": "保存标签列", "Whether or not to save real label columns into output pred file. If true, input feature_dataset must contain label columns and receiver party must be label owner.": "是否将真实的标签列保存到输出预测文件中;如果为 true,则输入feature_dataset必须包含标签列,并且接收方必须是标签所有者", + "batch_size": "batch_size", + "Prediction batch size": "预测批大小", "model": "模型", "feature_dataset": "特征数据集", "Input vertical table.": "输入联合表", @@ -449,12 +451,12 @@ "condition_filter": "条件筛选器", "Filter the table based on a single column's values and condition.\nWarning: the party responsible for condition filtering will directly send the sample distribution to other participants.\nMalicious participants can obtain the distribution of characteristics by repeatedly calling with different filtering values.\nAudit the usage of this component carefully.": "根据单个列的值和条件筛选表。\n警告:负责条件过滤的一方将直接将样本分发发送给其他参与者。\n恶意参与者可以通过使用不同的过滤值重复调用来获得特征的分布。\n仔细审核此组件的使用情况。", "0.0.1": "0.0.1", - "comparator": "比较器", - "Comparator to use for comparison. Must be one of '==','<','<=','>','>=','IN'": "用于比较的比较器。必须是'==='、'<'、'<='、'>'、'>='、'IN'之一", + "comparator": "比较条件", + "Comparator to use for comparison. Must be one of '==','<','<=','>','>=','IN'": "用于比较的条件。必须是'==='、'<'、'<='、'>'、'>='、'IN'之一", "value_type": "值类型", "Type of the value to compare with. Must be one of ['STRING', 'FLOAT']": "要与之进行比较的值的类型。必须是“STRING”、“FLOAT”中的一个", - "bound_value": "边界值", - "Input a str with values separated by ','. List of values to compare with. If comparator is not 'IN', we only support one element in this list.": "输入一个str,其值以“,”分隔。 表示比较的值的列表。如果comparator不是“IN”,则此列表中应该仅含一个元素。", + "bound_value": "条件值", + "Input a str with values separated by ','. List of values to compare with. If comparator is not 'IN', we only support one element in this list.": "输入一个str,其值以“,”分隔。 表示比较的值的列表。如果比较条件不是“IN”,则此列表中应该仅含一个元素。", "float_epsilon": "浮点数误差值", "Epsilon value for floating point comparison. WARNING: due to floating point representation in computers, set this number slightly larger if you want filter out the values exactly at desired boundary. for example, abs(1.001 - 1.002) is slightly larger than 0.001, and therefore may not be filter out using == and epsilson = 0.001": "用于浮点比较的Epsilon值。警告:由于计算机中的浮点表示,如果您想在所需的边界处过滤掉值,请将此数字设置得稍大一些。例如,abs(1.001-1.002)略大于0.001,因此可能无法使用==和epsilson=0.001进行过滤", "in_ds": "输入数据集", @@ -466,6 +468,22 @@ "out_ds_else": "输出数据集", "Output vertical table that does not satisfies the condition.": "输出不满足条件的垂直表格。" }, + "preprocessing/feature_calculate:0.0.1": { + "preprocessing": "预处理", + "feature_calculate": "特征计算", + "Generate a new feature by performing calculations on an origin feature": "对原特征进行操作生成新特征", + "0.0.1": "0.0.1", + "rules": "规则", + "input CalculateOpRules rules": "输入特征计算规则", + "in_ds": "输入数据集", + "Input vertical table": "输入联合表", + "features": "特征列", + "Feature(s) to operate on": "要操作的特征列", + "out_ds": "输出数据集", + "output_dataset": "输出数据集", + "out_rules": "输出规则", + "feature calculate rule": "特征计算规则" + }, "preprocessing/feature_filter:0.0.1": { "preprocessing": "预处理", "feature_filter": "特征过滤", @@ -482,8 +500,8 @@ "preprocessing": "预处理", "fillna": "异常值填充", "0.0.1": "0.0.1", - "strategy": "策略", - "The imputation strategy. If \"mean\", then replace missing values using the mean along each column. Can only be used with numeric data. If \"median\", then replace missing values using the median along each column. Can only be used with numeric data. If \"most_frequent\", then replace missing using the most frequent value along each column. Can be used with strings or numeric data. If there is more than one such value, only the smallest is returned. If \"constant\", then replace missing values with fill_value. Can be used with strings or numeric data.": "插补策略。如果为“平均值”,则使用每列的平均值替换缺失值。只能与数字数据一起使用。如果为“中值”,则使用每列的中值替换缺失值。只能与数字数据一起使用。如果为“most_frequency”,则使用每列中最频繁的值替换缺失的值。可以与字符串或数字数据一起使用。如果存在多个这样的值,则只返回最小的值。如果为“常量”,则用fill_value替换缺失的值。可以与字符串或数字数据一起使用。", + "strategy": "填充缺失值的方式", + "The imputation strategy. If \"mean\", then replace missing values using the mean along each column. Can only be used with numeric data. If \"median\", then replace missing values using the median along each column. Can only be used with numeric data. If \"most_frequent\", then replace missing using the most frequent value along each column. Can be used with strings or numeric data. If there is more than one such value, only the smallest is returned. If \"constant\", then replace missing values with fill_value. Can be used with strings or numeric data.": "插补策略。如果为“平均值”,则使用每列的平均值替换缺失值。只能与数字数据一起使用。如果为“中值”,则使用每列的中值替换缺失值。只能与数字数据一起使用。如果为“众数”,则使用每列中最频繁的值替换缺失的值。可以与字符串或数字数据一起使用。如果存在多个这样的值,则只返回最小的值。如果为“自定义值”,则用fill_value替换缺失的值。可以与字符串或数字数据一起使用。", "missing_value": "缺失值", "Which value should be treat as missing_value? int, float, str, general_na (includes np.nan, None or pandas.NA which are all null in sc.table), default=general_na": "哪个值应该被视为missing_value?int、float、str、general_na(包括sc.table中全部为null的np.nan、None或pandas.na),default=general_na", "missing_value_type": "缺失值类型", @@ -576,18 +594,18 @@ "test": "测试数据子集", "Output test dataset.": "输出测试数据子集" }, - "stats/groupby_statistics:0.0.2": { + "stats/groupby_statistics:0.0.3": { "stats": "统计", "groupby_statistics": "分组统计", "Get a groupby of statistics, like pandas groupby statistics.\nCurrently only support VDataframe.": "获取分组统计信息,参考pandas的分组统计。\n目前仅支持 VDataframe。", - "0.0.2": "0.0.2", + "0.0.3": "0.0.3", "aggregation_config": "聚合配置", "input groupby aggregation config": "输入聚合配置", "max_group_size": "最大组数", "The maximum number of groups allowed": "允许的最大组数", "input_data": "输入数据", "Input table.": "输入表", - "by": "组列", + "by": "特征列", "by what columns should we group the values": "我们应该按哪些列进行分组", "report": "报告", "Output groupby statistics report.": "输出分组统计信息报告" diff --git a/docs/awesome-pets/papers/applications/ppml/ppml_crypto.md b/docs/awesome-pets/papers/applications/ppml/ppml_crypto.md index a715e4f6c..34e0fc4d2 100644 --- a/docs/awesome-pets/papers/applications/ppml/ppml_crypto.md +++ b/docs/awesome-pets/papers/applications/ppml/ppml_crypto.md @@ -90,12 +90,12 @@ An overview of existing works is illustrated in the table below. Chinese Journal of Computers, [eprint in Chinese](http://cjc.ict.ac.cn/online/onlinepaper/hwl-202375100742.pdf) - When Machine Learning Meets Privacy: A Survey and Outlook. - *Bo Liu, Ming Ding, Sina Shaham, Wenny Rahayu, Farhad Farokhi, and Zihuai Lin* + *Bo Liu, Ming Ding, Sina Shaham, Wenny Rahayu, Farhad Farokhi, and Zihuai Lin* ACM Computing Surveys (CSUR), [eprint](https://arxiv.org/pdf/2011.11819.pdf) - Privacy-preserving machine learning: Methods, challenges and directions. - *Xu R, Baracaldo N, Joshi J* + *Xu R, Baracaldo N, Joshi J* arXiv preprint arXiv, [eprint](https://arxiv.org/pdf/2108.04417.pdf) ## Two-party Computation (2PC) diff --git a/requirements.txt b/requirements.txt index 9822be363..6aa9a551a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,7 +28,7 @@ secretflow-rayfed==0.2.0a7 # FEATURE=[lite] setuptools>=65.5.1 sparse>=0.14.0 spu==0.6.0.b0 # FEATURE=[lite] -sf-heu==0.5.0.dev20231118 # FEATURE=[lite] +sf-heu==0.5.0.dev20231128 # FEATURE=[lite] tensorflow-macos==2.11.0; platform_machine == "arm64" and platform_system == "Darwin" tensorflow==2.11.1; platform_machine != "arm64" tf2onnx>=1.13.0 @@ -40,4 +40,5 @@ wheel>=0.38.1 torch==2.1.1 torchmetrics==0.11.4 torchvision==0.16.1 -torchaudio==2.1.1 \ No newline at end of file +torchaudio==2.1.1 +interconnection==0.1.0.dev20231204 diff --git a/secretflow/component/batch_reader.py b/secretflow/component/batch_reader.py new file mode 100644 index 000000000..37c383061 --- /dev/null +++ b/secretflow/component/batch_reader.py @@ -0,0 +1,156 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict, List + +import pandas as pd +import pyarrow.csv as csv + +from secretflow import PYU, reveal + + +class SimpleBatchReader: + def __init__(self, path, batch_size, col, to_numpy): + self.path = path + self.batch_size = batch_size + self.col = col + self.to_numpy = to_numpy + self.read_idx_in_batch = 0 + self.total_read_cnt = 0 + self.batch_idx = 0 + self.end = False + + def __iter__(self): + return self + + def read_next(self): + convert_options = csv.ConvertOptions() + if self.col: + convert_options.include_columns = self.col + reader = csv.open_csv(self.path, convert_options=convert_options) + batch = None + + for _ in range(self.batch_idx): + batch = reader.read_next_batch() + + res = [] + res_cnt = 0 + + while not self.end and res_cnt < self.batch_size: + if batch is None or self.read_idx_in_batch >= batch.num_rows: + try: + batch = reader.read_next_batch() + self.batch_idx += 1 + self.read_idx_in_batch = 0 + except StopIteration: + self.end = True + break + + if (batch.num_rows - self.read_idx_in_batch) > (self.batch_size - res_cnt): + res.append( + batch.slice( + self.read_idx_in_batch, self.batch_size - res_cnt + ).to_pandas() + ) + + self.read_idx_in_batch += self.batch_size - res_cnt + res_cnt += self.batch_size - res_cnt + else: + res.append(batch.slice(self.read_idx_in_batch).to_pandas()) + res_cnt += batch.num_rows - self.read_idx_in_batch + self.read_idx_in_batch = batch.num_rows + + self.total_read_cnt += res_cnt + + if res_cnt == 0: + return None + else: + res = pd.concat(res, axis=0) + if self.to_numpy: + return res.to_numpy() + else: + return res + + def __next__(self): + next_batch = self.read_next() + + if next_batch is None: + raise StopIteration + else: + return next_batch + + +class SimpleVerticalBatchReader: + def __init__( + self, + paths: Dict[str, str], + batch_size: int = 100000, + cols: Dict[str, List[str]] = None, + to_numpy=False, + ) -> None: + self.readers = {} + assert len(paths) > 1, "Cnt of parties should be greater than 1." + for party, path in paths.items(): + pyu = PYU(party) + self.readers[party] = pyu( + lambda path, batch_size, col, to_numpy: SimpleBatchReader( + path, batch_size, col, to_numpy + ) + )(path, batch_size, cols.get(party) if cols else None, to_numpy) + + def __iter__(self): + return self + + def __next__(self): + def read_next_batch_wrapper(reader): + next_batch = reader.read_next() + return next_batch is None, next_batch, reader + + batches = {} + end_flags = [] + + new_readers = {} + + for party, reader in self.readers.items(): + pyu = PYU(party) + end_flag, batch, new_reader = pyu(read_next_batch_wrapper, num_returns=3)( + reader + ) + + new_readers[party] = new_reader + + batches[party] = batch + end_flags.append(end_flag) + + self.readers = new_readers + + end_flags = reveal(end_flags) + + assert all( + x == end_flags[0] for x in end_flags + ), "end_flags are different between parties. Make sure the samples are aligned before. You may run PSI to fix." + + if end_flags[0]: + raise StopIteration + + return batches + + def total_read_cnt(self): + party = next(iter(self.readers)) + reader = self.readers[party] + + pyu = PYU(party) + cnt = pyu(lambda reader: reader.total_read_cnt)(reader) + return reveal(cnt) diff --git a/secretflow/component/data_utils.py b/secretflow/component/data_utils.py index e15b3bbd6..ea264bc5e 100644 --- a/secretflow/component/data_utils.py +++ b/secretflow/component/data_utils.py @@ -22,6 +22,7 @@ import numpy as np import pandas as pd +from secretflow.data.core.io import read_file_meta from secretflow.data.vertical import read_csv from secretflow.data.vertical.dataframe import VDataFrame from secretflow.device.device.pyu import PYU, PYUObject @@ -463,8 +464,11 @@ def dump_vertical_table( } wait(v_data.to_csv(output_path, index=False)) order = [p.party for p in v_data.partitions] + file_metas = {} + for pyu in output_path: + file_metas[pyu] = reveal(pyu(read_file_meta)(output_path[pyu])) logging.info( - f"dumped VDataFrame, file uri {output_path}, samples {parties_length}" + f"dumped VDataFrame, file uri {output_path}, samples {parties_length}, file meta {file_metas}" ) ret = DistData( @@ -647,6 +651,7 @@ def save_prediction_csv( label_keys: List[str] = None, id_df: pd.DataFrame = None, id_keys: List[str] = None, + try_append: bool = False, ) -> None: x = pd.DataFrame(pred_df, columns=[pred_key]) @@ -657,7 +662,15 @@ def save_prediction_csv( id = pd.DataFrame(id_df, columns=id_keys) x = pd.concat([x, id], axis=1) - x.to_csv(path, index=False) + import os + + if try_append: + if not os.path.isfile(path): + x.to_csv(path, index=False) + else: + x.to_csv(path, mode='a', header=False, index=False) + else: + x.to_csv(path, index=False) def gen_prediction_csv_meta( diff --git a/secretflow/component/entry.py b/secretflow/component/entry.py index d355491a7..56f52b5b8 100644 --- a/secretflow/component/entry.py +++ b/secretflow/component/entry.py @@ -30,6 +30,7 @@ from secretflow.component.preprocessing.binary_op import binary_op_comp from secretflow.component.preprocessing.case_when import case_when from secretflow.component.preprocessing.condition_filter import condition_filter_comp +from secretflow.component.preprocessing.feature_calculate import feature_calculate from secretflow.component.preprocessing.feature_filter import feature_filter_comp from secretflow.component.preprocessing.fillna import fillna from secretflow.component.preprocessing.onehot_encode import onehot_encode @@ -81,8 +82,10 @@ fillna, io_read_data, io_write_data, + feature_calculate, identity, ] + COMP_LIST_NAME = "secretflow" COMP_LIST_DESC = "First-party SecretFlow components." COMP_LIST_VERSION = "0.0.1" @@ -114,7 +117,7 @@ def generate_comp_list(): def get_comp_def(domain: str, name: str, version: str) -> ComponentDef: key = gen_key(domain, name, version) - assert key in COMP_MAP, f"key {key} is not in compute map {COMP_MAP}" + assert key in COMP_MAP, f"key {key} is not in component list {COMP_LIST}" return COMP_MAP[key].definition() diff --git a/secretflow/component/io/core/bins/woe_bin_utils.py b/secretflow/component/io/core/bins/woe_bin_utils.py index c0debf6f1..f32dd3d9a 100644 --- a/secretflow/component/io/core/bins/woe_bin_utils.py +++ b/secretflow/component/io/core/bins/woe_bin_utils.py @@ -67,7 +67,12 @@ def calculate_woe_from_ratios(rp: float, rn: float): return np.log(rp / rn) -def calculate_bin_ratios(bin_iv: float, bin_woe: float) -> Tuple[float, float]: +FLOAT_TOLERANCE = 1e-6 + + +def calculate_bin_ratios( + bin_iv: float, bin_woe: float, bin_total: int, feature_total: int +) -> Tuple[float, float]: """ Calculate two ratios: rp = bin_pos/total_positive, rn = bin_negative/total negative. @@ -79,6 +84,10 @@ def calculate_bin_ratios(bin_iv: float, bin_woe: float) -> Tuple[float, float]: bin_woe = log((bin_positives / total_positives) / (bin_negatives / total_negatives)) bin_iv = ((bin_positives / total_positives) - (bin_negatives / total_negatives)) * bin_woe """ + if abs(bin_woe) <= FLOAT_TOLERANCE: + x = bin_total * 1.0 / feature_total + y = x + return x, y D = bin_iv / bin_woe y = D / (np.exp(bin_woe) - 1) x = y + D @@ -98,10 +107,13 @@ def compute_bin_ratios(merged_rule: Dict) -> Dict: woes = variable["filling_values"] ivs = feature_iv["ivs"] num_bins = len(bin_counts) + feature_total = sum(bin_counts) bin_ratios = [ calculate_bin_ratios( bin_woe=woes[i], bin_iv=ivs[i], + bin_total=bin_counts[i], + feature_total=feature_total, ) for i in range(num_bins) ] diff --git a/secretflow/component/io/identity.py b/secretflow/component/io/identity.py index 2fd6d03aa..0b15a8175 100644 --- a/secretflow/component/io/identity.py +++ b/secretflow/component/io/identity.py @@ -16,9 +16,7 @@ from secretflow.component.component import CompEvalError, Component, IoType from secretflow.component.data_utils import DistDataType, model_dumps, model_loads - from secretflow.device.device.spu import SPU - from secretflow.spec.extend.data_pb2 import DeviceObjectCollection from secretflow.spec.v1.data_pb2 import DistData @@ -41,14 +39,14 @@ identity.io( io_type=IoType.INPUT, name="input_data", - desc="Input dist data", + desc="Input data", types=IDENTITY_SUPPORTED_TYPES, ) identity.io( io_type=IoType.OUTPUT, name="output_data", - desc="Output dist data", + desc="Output data", types=IDENTITY_SUPPORTED_TYPES, ) diff --git a/secretflow/component/ml/boost/sgb/sgb.py b/secretflow/component/ml/boost/sgb/sgb.py index 99ae0148e..c0663533e 100644 --- a/secretflow/component/ml/boost/sgb/sgb.py +++ b/secretflow/component/ml/boost/sgb/sgb.py @@ -14,9 +14,11 @@ import json import os +from secretflow.component.batch_reader import SimpleVerticalBatchReader from secretflow.component.component import Component, IoType, TableColParam from secretflow.component.data_utils import ( DistDataType, + extract_distdata_info, extract_table_header, gen_prediction_csv_meta, get_model_public_info, @@ -481,6 +483,13 @@ def sgb_train_eval_fn( types=[DistDataType.INDIVIDUAL_TABLE], col_params=None, ) +sgb_predict_comp.int_attr( + name="batch_size", + desc="Prediction batch size", + is_list=False, + is_optional=True, + default_value=100000, +) def load_sgb_model(ctx, pyus, model) -> SgbModel: @@ -531,88 +540,96 @@ def sgb_predict_eval_fn( pred, save_ids, save_label, + batch_size, ): model_public_info = get_model_public_info(model) - x = load_table( - ctx, + + v_headers = extract_table_header( feature_dataset, load_features=True, load_labels=True, col_selects=model_public_info['feature_selects'], ) - pyus = {p.party: p for p in x.partitions.keys()} - model = load_sgb_model(ctx, pyus, model) + parties_path_format = extract_distdata_info(feature_dataset) + filepaths = { + p: os.path.join(ctx.local_fs_wd, parties_path_format[p].uri) for p in v_headers + } - with ctx.tracer.trace_running(): - pyu = PYU(receiver) - pyu_y = model.predict(x, pyu) + cols = {k: list(v.keys()) for k, v in v_headers.items()} - y_path = os.path.join(ctx.local_fs_wd, pred) + feature_reader = SimpleVerticalBatchReader(filepaths, batch_size, cols, True) - if save_ids: - id_df = load_table(ctx, feature_dataset, load_ids=True) - assert pyu in id_df.partitions - id_header_map = extract_table_header(feature_dataset, load_ids=True) - assert receiver in id_header_map - id_header = list(id_header_map[receiver].keys()) - id_data = id_df.partitions[pyu].data - else: - id_header_map = None - id_header = None - id_data = None - - if save_label: - label_df = load_table( - ctx, - feature_dataset, - load_features=True, - load_labels=True, - col_selects=model_public_info['label_col'], - ) - assert pyu in label_df.partitions - label_header_map = extract_table_header( - feature_dataset, - load_features=True, - load_labels=True, - col_selects=model_public_info['label_col'], - ) - assert receiver in label_header_map - label_header = list(label_header_map[receiver].keys()) - label_data = label_df.partitions[pyu].data - else: - label_header_map = None - label_header = None - label_data = None - - wait( - pyu(save_prediction_csv)( - pyu_y.partitions[pyu], - pred_name, - y_path, - label_data, - label_header, - id_data, - id_header, - ) + pyus = {p: PYU(p) for p in v_headers.keys()} + sgb_model = load_sgb_model(ctx, pyus, model) + + if save_ids: + id_header_map = extract_table_header(feature_dataset, load_ids=True) + assert receiver in id_header_map + id_header = list(id_header_map[receiver].keys()) + id_reader = SimpleVerticalBatchReader( + filepaths, batch_size, id_header_map, True + ) + else: + id_header_map = None + id_header = None + id_reader = None + + if save_label: + label_header_map = extract_table_header( + feature_dataset, + load_features=True, + load_labels=True, + col_selects=model_public_info['label_col'], + ) + assert receiver in label_header_map + label_header = list(label_header_map[receiver].keys()) + label_reader = SimpleVerticalBatchReader( + filepaths, batch_size, label_header_map, True ) + else: + label_header_map = None + label_header = None + label_reader = None - y_db = DistData( - name=pred_name, - type=str(DistDataType.INDIVIDUAL_TABLE), - data_refs=[DistData.DataRef(uri=pred, party=receiver, format="csv")], - ) + try_append = False + with ctx.tracer.trace_running(): + receiver_pyu = PYU(receiver) + y_path = os.path.join(ctx.local_fs_wd, pred) + for batch in feature_reader: + new_batch = {PYU(party): batch[party] for party in v_headers} + pyu_y = sgb_model.predict(new_batch, receiver_pyu) + + wait( + receiver_pyu(save_prediction_csv)( + pyu_y.partitions[receiver_pyu], + pred_name, + y_path, + next(label_reader)[receiver] if label_reader else None, + label_header, + next(id_reader)[receiver] if id_reader else None, + id_header, + try_append, + ) + ) + try_append = True - meta = gen_prediction_csv_meta( - id_header=id_header_map, - label_header=label_header_map, - party=receiver, - pred_name=pred_name, - line_count=x.shape[0], - id_keys=id_header, - label_keys=label_header, - ) + y_db = DistData( + name=pred_name, + type=str(DistDataType.INDIVIDUAL_TABLE), + data_refs=[DistData.DataRef(uri=pred, party=receiver, format="csv")], + ) + + meta = gen_prediction_csv_meta( + id_header=id_header_map, + label_header=label_header_map, + party=receiver, + pred_name=pred_name, + line_count=feature_reader.total_read_cnt(), + id_keys=id_header, + label_keys=label_header, + ) - y_db.meta.Pack(meta) + y_db.meta.Pack(meta) - return {"pred": y_db} + return {"pred": y_db} diff --git a/secretflow/component/preprocessing/condition_filter.py b/secretflow/component/preprocessing/condition_filter.py index c5beaae52..a0a33668c 100644 --- a/secretflow/component/preprocessing/condition_filter.py +++ b/secretflow/component/preprocessing/condition_filter.py @@ -58,9 +58,10 @@ name="float_epsilon", desc="Epsilon value for floating point comparison. WARNING: due to floating point representation in computers, set this number slightly larger if you want filter out the values exactly at desired boundary. for example, abs(1.001 - 1.002) is slightly larger than 0.001, and therefore may not be filter out using == and epsilson = 0.001", is_list=False, - is_optional=False, + is_optional=True, lower_bound=0, lower_bound_inclusive=True, + default_value=0.000001, ) condition_filter_comp.io( diff --git a/secretflow/component/preprocessing/feature_calculate.py b/secretflow/component/preprocessing/feature_calculate.py new file mode 100644 index 000000000..9b5dffbd9 --- /dev/null +++ b/secretflow/component/preprocessing/feature_calculate.py @@ -0,0 +1,278 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +import pyarrow as pa +import math + +from google.protobuf.json_format import MessageToJson, Parse + +import secretflow.compute as sc +from secretflow.component.component import Component, IoType, TableColParam +from secretflow.component.data_utils import DistDataType +from secretflow.component.preprocessing.core.table_utils import ( + v_preprocessing_transform, +) +from secretflow.spec.extend.calculate_rules_pb2 import CalculateOpRules + +feature_calculate = Component( + "feature_calculate", + domain="preprocessing", + version="0.0.1", + desc="Generate a new feature by performing calculations on an origin feature", +) + +feature_calculate.custom_pb_attr( + name="rules", + desc="input CalculateOpRules rules", + pb_cls=CalculateOpRules, +) + +feature_calculate.io( + io_type=IoType.INPUT, + name="in_ds", + desc="Input vertical table", + types=[DistDataType.VERTICAL_TABLE], + col_params=[ + TableColParam( + name="features", + desc="Feature(s) to operate on", + col_min_cnt_inclusive=1, + ) + ], +) + +feature_calculate.io( + io_type=IoType.OUTPUT, + name="out_ds", + desc="output_dataset", + types=[DistDataType.VERTICAL_TABLE], + col_params=None, +) + +feature_calculate.io( + io_type=IoType.OUTPUT, + name="out_rules", + desc="feature calculate rule", + types=[DistDataType.PREPROCESSING_RULE], + col_params=None, +) + + +def apply_feature_calcute_rule( + table: sc.Table, rules: CalculateOpRules, in_ds_features +) -> sc.Table: + def _check_numuric(type): + assert pa.types.is_floating(type) or pa.types.is_integer( + type + ), f"operator only support float/int, but got {type}" + + def _check_text(type): + assert pa.types.is_string(type), f"operator only support string, but got {type}" + + # std = (x-mean)/stde + def _apply_standardize(col: sc.Array): + _check_numuric(col.dtype) + pd_col = col.to_pandas() + # const column, set column elements to 0s + if pd_col.nunique() == 1: + new_col = sc.multiply(col, 0) + else: + mean = pd_col.mean() + stde = pd_col.std(ddof=0) + new_col = sc.divide(sc.subtract(col, mean), stde) + return new_col + + # norm = (x-min)/(max-min) + def _apply_normalize(col: sc.Array): + _check_numuric(col.dtype) + pd_col = col.to_pandas() + # const column, set column elements to 0s + if pd_col.nunique() == 1: + new_col = sc.multiply(col, 0) + else: + max = pd_col.max() + min = pd_col.min() + new_col = sc.divide(sc.subtract(col, min), float(max - min)) + return new_col + + def _apply_range_limit(col: sc.Array): + _check_numuric(col.dtype) + op_cnt = len(rules.operands) + assert op_cnt == 2, f"range limit operator need 2 operands, but got {op_cnt}" + op0 = float(rules.operands[0]) + op1 = float(rules.operands[1]) + assert ( + op0 <= op1 + ), f"range limit operator expect min <= max, but get [{op0}, {op1}]" + + conds = [sc.less(col, op0), sc.greater(col, op1)] + cases = [op0, op1, col] + new_col = sc.case_when(sc.make_struct(*conds), *cases) + return new_col + + def _apply_unary(col: sc.Array): + _check_numuric(col.dtype) + op_cnt = len(rules.operands) + assert op_cnt == 3, f"unary operator needs 3 operands, but got {op_cnt}" + op0 = rules.operands[0] + assert op0 in ['+', '-'], f"unary op0 should be [+ - r], but get {op0}" + op1 = rules.operands[1] + assert op1 in [ + '+', + '-', + '*', + '/', + ], f"unary op1 should be [+ - * /], but get {op1}" + op3 = float(rules.operands[2]) + if op1 == "+": + new_col = sc.add(col, op3) + elif op1 == "-": + new_col = sc.subtract(col, op3) if op0 == "+" else sc.subtract(op3, col) + elif op1 == "*": + new_col = sc.multiply(col, op3) + elif op1 == "/": + if op0 == "+": + assert op3 != 0, "unary operator divide zero" + new_col = sc.divide(col, op3) + else: + new_col = sc.divide(op3, col) + return new_col + + def _apply_reciprocal(col: sc.Array): + _check_numuric(col.dtype) + new_col = sc.divide(1.0, col) + return new_col + + def _apply_round(col: sc.Array): + _check_numuric(col.dtype) + new_col = sc.round(col) + return new_col + + def _apply_log_round(col: sc.Array): + _check_numuric(col.dtype) + op_cnt = len(rules.operands) + assert op_cnt == 1, f"log operator needs 1 operands, but got {op_cnt}" + op0 = float(rules.operands[0]) + new_col = sc.round(sc.log2(sc.add(col, op0))) + return new_col + + def _apply_sqrt(col: sc.Array): + _check_numuric(col.dtype) + # TODO: whether check positive? sqrt will return a NaN when meets negative argument + new_col = sc.sqrt(col) + return new_col + + def _apply_log(col: sc.Array): + _check_numuric(col.dtype) + op_cnt = len(rules.operands) + assert op_cnt == 2, f"log operator needs 2 operands, but got {op_cnt}" + op0 = rules.operands[0] + op1 = float(rules.operands[1]) + if op0 == "e": + new_col = sc.multiply(sc.log2(sc.add(col, op1)), math.log(2, math.e)) + else: + new_col = sc.logb(sc.add(col, op1), float(op0)) + return new_col + + def _apply_exp(col: sc.Array): + _check_numuric(col.dtype) + new_col = sc.exp(col) + return new_col + + def _apply_lenth(col: sc.Array): + _check_text(col.dtype) + new_col = sc.utf8_length(col) + return new_col + + def _apply_substr(col: sc.Array): + _check_text(col.dtype) + op_cnt = len(rules.operands) + assert op_cnt == 2, f"substr operator need 2 oprands, but get {op_cnt}" + start = int(rules.operands[0]) + lenth = int(rules.operands[1]) + new_col = sc.utf8_slice_codeunits(col, start, start + lenth) + return new_col + + for feature in in_ds_features: + if feature in table.column_names: + col = table.column(feature) + if rules.op == CalculateOpRules.OpType.STANDARDIZE: + new_col = _apply_standardize(col) + elif rules.op == CalculateOpRules.OpType.NORMALIZATION: + new_col = _apply_normalize(col) + elif rules.op == CalculateOpRules.OpType.RANGE_LIMIT: + new_col = _apply_range_limit(col) + elif rules.op == CalculateOpRules.OpType.UNARY: + new_col = _apply_unary(col) + elif rules.op == CalculateOpRules.OpType.RECIPROCAL: + new_col = _apply_reciprocal(col) + elif rules.op == CalculateOpRules.OpType.ROUND: + new_col = _apply_round(col) + elif rules.op == CalculateOpRules.OpType.LOG_ROUND: + new_col = _apply_log_round(col) + elif rules.op == CalculateOpRules.OpType.SQRT: + new_col = _apply_sqrt(col) + elif rules.op == CalculateOpRules.OpType.LOG: + new_col = _apply_log(col) + elif rules.op == CalculateOpRules.OpType.EXP: + new_col = _apply_exp(col) + elif rules.op == CalculateOpRules.OpType.LENGTH: + new_col = _apply_lenth(col) + elif rules.op == CalculateOpRules.OpType.SUBSTR: + new_col = _apply_substr(col) + else: + raise AttributeError(f"unknown rules.op {rules.op}") + table = table.set_column( + table.column_names.index(feature), + feature, + new_col, + ) + return table + + +@feature_calculate.eval_fn +def feature_calculate_eval_fn( + *, + ctx, + rules: CalculateOpRules, + in_ds, + in_ds_features, + out_rules, + out_ds, +): + assert in_ds.type == DistDataType.VERTICAL_TABLE, "only support vtable for now" + str_rule = MessageToJson(rules) + + def _transform(data: pd.DataFrame): + import secretflow.spec.extend.calculate_rules_pb2 as pb + + rules = Parse(str_rule, pb.CalculateOpRules()) + data = apply_feature_calcute_rule( + sc.Table.from_pandas(data), rules, in_ds_features + ) + return data, [], None + + (out_ds, model_dd, _) = v_preprocessing_transform( + ctx, + in_ds, + in_ds_features, + _transform, + out_ds, + out_rules, + "Feature Calculate", + assert_one_party=False, + ) + + return {"out_rules": model_dd, "out_ds": out_ds} diff --git a/secretflow/component/preprocessing/fillna.py b/secretflow/component/preprocessing/fillna.py index 0884d5a0e..9289941de 100644 --- a/secretflow/component/preprocessing/fillna.py +++ b/secretflow/component/preprocessing/fillna.py @@ -49,7 +49,7 @@ """, is_list=False, is_optional=True, - default_value="mean", + default_value="constant", allowed_values=SUPPORTED_FILL_NA_METHOD, ) @@ -58,7 +58,7 @@ desc="Which value should be treat as missing_value? int, float, str, general_na (includes np.nan, None or pandas.NA which are all null in sc.table), default=general_na", is_list=False, is_optional=True, - default_value="general_na", + default_value="custom_missing_value", ) NA_SUPPORTED_TYPE_DICT = { @@ -223,7 +223,9 @@ def fit(data): data.columns ), f"strategy {strategy} works only on numerical columns, select only numerical columns" fillna_rules = generate_rule_dict(missing_value, data, strategy) - else: + elif strategy == "most_frequent": + fillna_rules = generate_rule_dict(missing_value, data, strategy) + elif strategy == "constant": fillna_rules = generate_rule_dict_constant( missing_value, fill_value_int, @@ -231,6 +233,8 @@ def fit(data): fill_value_str, data, ) + else: + assert ValueError(f"Unsupported strategy {strategy}") return fillna_rules def fillna_fit_transform(trans_data): diff --git a/secretflow/component/stats/groupby_statistics.py b/secretflow/component/stats/groupby_statistics.py index a953013e5..6efddaa63 100644 --- a/secretflow/component/stats/groupby_statistics.py +++ b/secretflow/component/stats/groupby_statistics.py @@ -36,7 +36,7 @@ groupby_statistics_comp = Component( name="groupby_statistics", domain="stats", - version="0.0.2", + version="0.0.3", desc="""Get a groupby of statistics, like pandas groupby statistics. Currently only support VDataframe. """, diff --git a/secretflow/compute/tracer.py b/secretflow/compute/tracer.py index 92d6a151d..d75ad06cc 100644 --- a/secretflow/compute/tracer.py +++ b/secretflow/compute/tracer.py @@ -300,6 +300,9 @@ def __init__(self, arrow: pa.ChunkedArray, trace: _Tracer): def dtype(self) -> pa.DataType: return self._arrow.type + def to_pandas(self) -> pd.Series: + return self._arrow.to_pandas() + class Table: def __init__(self, table: pa.Table, trace: _Tracer): diff --git a/secretflow/data/core/agent.py b/secretflow/data/core/agent.py index b7ced7d7c..3d8ed9fe7 100644 --- a/secretflow/data/core/agent.py +++ b/secretflow/data/core/agent.py @@ -291,6 +291,7 @@ def to_csv(self, idx: AgentIndex, filepath, **kwargs): """Save DataFrame to csv file.""" working_object = self.working_objects[idx] working_object.to_csv(filepath, **kwargs) + return True def iloc(self, idx: AgentIndex, index: Union[int, slice, List[int]]) -> AgentIndex: working_object = self.working_objects[idx] diff --git a/secretflow/data/core/io.py b/secretflow/data/core/io.py index c5b38f8ca..2807e226c 100644 --- a/secretflow/data/core/io.py +++ b/secretflow/data/core/io.py @@ -12,11 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import platform from typing import Union import pandas as pd +def read_file_meta(path: str): + ret = {} + ret["ctime"] = os.path.getctime(path) + ret["mtime"] = os.path.getmtime(path) + ret["size"] = os.path.getsize(path) + if platform.system() == 'Linux': + ret["inode"] = os.stat(path).st_ino + return ret + + def read_csv_wrapper( filepath: str, auto_gen_header_prefix: str = "", read_backend="pandas", **kwargs ) -> Union[pd.DataFrame, "pl.DataFrame"]: diff --git a/secretflow/data/core/partition.py b/secretflow/data/core/partition.py index 2eeada49a..650fd4edb 100644 --- a/secretflow/data/core/partition.py +++ b/secretflow/data/core/partition.py @@ -264,7 +264,7 @@ def fillna( return Partition(self.part_agent, data_idx, self.device, self.backend) def to_csv(self, filepath, **kwargs): - self.part_agent.to_csv(self.agent_idx, filepath, **kwargs) + return self.part_agent.to_csv(self.agent_idx, filepath, **kwargs) def iloc(self, index: Union[int, slice, List[int]]) -> 'Partition': data_idx = self.part_agent.iloc(self.agent_idx, index) diff --git a/secretflow/data/core/polars/util.py b/secretflow/data/core/polars/util.py index c798d982c..3bd65581f 100644 --- a/secretflow/data/core/polars/util.py +++ b/secretflow/data/core/polars/util.py @@ -21,16 +21,17 @@ def read_polars_csv(filepath, *args, **kwargs): if 'delimiter' in kwargs and kwargs['delimiter'] is not None: - kwargs['separator'] = kwargs['delimiter'] + kwargs['separator'] = kwargs.pop('delimiter') if 'usecols' in kwargs and kwargs['usecols'] is not None: # polars only recognized list columns but not dictkeys. - kwargs['columns'] = list(kwargs['usecols']) + kwargs['columns'] = list(kwargs.pop('usecols')) if 'dtype' in kwargs and kwargs['dtype'] is not None: pl_dtypes = {} - for col, dt in kwargs['dtype'].items(): + for col, dt in kwargs.pop('dtype').items(): pl_dtypes[col] = infer_pl_dtype(dt) kwargs['dtypes'] = pl_dtypes - del kwargs['delimiter'], kwargs['dtype'], kwargs['usecols'] + if 'nrows' in kwargs and kwargs['nrows'] is not None: + kwargs['n_rows'] = kwargs.pop('nrows') df = pl.read_csv(filepath, *args, **kwargs) if len(df.columns) == 1: # for compatibility of pandas, single columns will drop null when read. diff --git a/secretflow/data/vertical/io.py b/secretflow/data/vertical/io.py index fd6d07ebc..85dbbabdf 100644 --- a/secretflow/data/vertical/io.py +++ b/secretflow/data/vertical/io.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. + from typing import Dict, List, Union -from secretflow.device import Device, PYU, SPU +from secretflow.device import PYU, SPU, Device, reveal from secretflow.utils.errors import InvalidArgumentError from secretflow.utils.random import global_random from ..core import partition -from ..core.io import read_csv_wrapper +from ..core.io import read_csv_wrapper, read_file_meta from .dataframe import VDataFrame @@ -164,9 +165,14 @@ def get_keys( parties_length = {} for device, part in partitions.items(): parties_length[device.party] = len(part) - assert ( - len(set(parties_length.values())) == 1 - ), f"number of samples must be equal across all devices, got {parties_length}, input uri {filepath_actual}" + if len(set(parties_length.values())) > 1: + file_metas = {} + for pyu in filepath_actual: + file_metas[pyu] = reveal(pyu(read_file_meta)(filepath_actual[pyu])) + raise AssertionError( + f"number of samples must be equal across all devices, got {parties_length}, " + f"input uri {filepath_actual}, input file meta {file_metas}" + ) for device, part in partitions.items(): for col in part.columns: diff --git a/secretflow/device/driver.py b/secretflow/device/driver.py index cacd579f8..7ba7c8b25 100644 --- a/secretflow/device/driver.py +++ b/secretflow/device/driver.py @@ -587,14 +587,22 @@ def barrier(): reveal(barriers) -def shutdown(): +def shutdown(barrier_on_shutdown=True): """Disconnect the worker, and terminate processes started by secretflow.init(). This will automatically run at the end when a Python process that uses Ray exits. It is ok to run this twice in a row. The primary use case for this function is to cleanup state between tests. + + Args: + barrier_on_shutdown: whether barrier on shutdown. It's useful in some cases + , e.g., reusing the port between multi secretflow tasks. Possible side + effects that may come with it at the same time, e.g., alice exits + accidently and bob will wait forever since alice will never give bob a + feedback. The default value is True. """ - barrier() + if barrier_on_shutdown: + barrier() sfd.shutdown() diff --git a/secretflow/ic/handler/handler.py b/secretflow/ic/handler/handler.py index 69a018e98..e5f676ae8 100644 --- a/secretflow/ic/handler/handler.py +++ b/secretflow/ic/handler/handler.py @@ -15,8 +15,8 @@ import abc import logging from typing import Tuple, List -from secretflow.ic.proto.handshake.entry_pb2 import HandshakeRequest, HandshakeResponse -from secretflow.ic.proto.common.header_pb2 import OK +from interconnection.handshake.entry_pb2 import HandshakeRequest, HandshakeResponse +from interconnection.common.header_pb2 import OK from secretflow.ic.proxy import LinkProxy diff --git a/secretflow/ic/handler/protocol_family/phe.py b/secretflow/ic/handler/protocol_family/phe.py index 4df6b19d6..6ed6e9a8f 100644 --- a/secretflow/ic/handler/protocol_family/phe.py +++ b/secretflow/ic/handler/protocol_family/phe.py @@ -13,7 +13,7 @@ # limitations under the License. -from secretflow.ic.proto.handshake.protocol_family import phe_pb2 +from interconnection.handshake.protocol_family import phe_pb2 class PheConfig: diff --git a/secretflow/ic/handler/sgb_handler.py b/secretflow/ic/handler/sgb_handler.py index c71d6c850..ac64bbafe 100644 --- a/secretflow/ic/handler/sgb_handler.py +++ b/secretflow/ic/handler/sgb_handler.py @@ -21,10 +21,10 @@ from google.protobuf import any_pb2 from secretflow.ic.handler.protocol_family import phe from secretflow.ic.handler.algo import xgb -from secretflow.ic.proto.common import header_pb2 -from secretflow.ic.proto.handshake import entry_pb2 -from secretflow.ic.proto.handshake.algos import sgb_pb2 -from secretflow.ic.proto.handshake.protocol_family import phe_pb2 +from interconnection.common import header_pb2 +from interconnection.handshake import entry_pb2 +from interconnection.handshake.algos import sgb_pb2 +from interconnection.handshake.protocol_family import phe_pb2 from secretflow.ic.proxy import LinkProxy from secretflow.ic.handler.handler import IcHandler from secretflow.ic.handler import util diff --git a/secretflow/ic/proto/__init__.py b/secretflow/ic/proto/__init__.py deleted file mode 100644 index 724997cbd..000000000 --- a/secretflow/ic/proto/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/secretflow/ic/proto/common/__init__.py b/secretflow/ic/proto/common/__init__.py deleted file mode 100644 index 724997cbd..000000000 --- a/secretflow/ic/proto/common/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/secretflow/ic/proto/common/header.proto b/secretflow/ic/proto/common/header.proto deleted file mode 100644 index ac3dd004a..000000000 --- a/secretflow/ic/proto/common/header.proto +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// [Sphinx doc begin anchor: ResponseHeader] -syntax = "proto3"; - -package org.interconnection; - -// 31100xxx is the white box interconnection code segment -// 31100xxx 为引擎白盒互联互通号段 -enum ErrorCode { - OK = 0; - - GENERIC_ERROR = 31100000; - UNEXPECTED_ERROR = 31100001; - NETWORK_ERROR = 31100002; - - INVALID_REQUEST = 31100100; - INVALID_RESOURCE = 31100101; - - HANDSHAKE_REFUSED = 31100200; - UNSUPPORTED_VERSION = 31100201; - UNSUPPORTED_ALGO = 31100202; - UNSUPPORTED_PARAMS = 31100203; -} - -message ResponseHeader { - int32 error_code = 1; - string error_msg = 2; -} -// [Sphinx doc end anchor: ResponseHeader] diff --git a/secretflow/ic/proto/common/header_pb2.py b/secretflow/ic/proto/common/header_pb2.py deleted file mode 100644 index be40426a4..000000000 --- a/secretflow/ic/proto/common/header_pb2.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: secretflow/ic/proto/common/header.proto -"""Generated protocol buffer code.""" -from google.protobuf.internal import enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'secretflow/ic/proto/common/header.proto\x12\x13org.interconnection\"7\n\x0eResponseHeader\x12\x12\n\nerror_code\x18\x01 \x01(\x05\x12\x11\n\terror_msg\x18\x02 \x01(\t*\xf3\x01\n\tErrorCode\x12\x06\n\x02OK\x10\x00\x12\x14\n\rGENERIC_ERROR\x10\xe0\x98\xea\x0e\x12\x17\n\x10UNEXPECTED_ERROR\x10\xe1\x98\xea\x0e\x12\x14\n\rNETWORK_ERROR\x10\xe2\x98\xea\x0e\x12\x16\n\x0fINVALID_REQUEST\x10\xc4\x99\xea\x0e\x12\x17\n\x10INVALID_RESOURCE\x10\xc5\x99\xea\x0e\x12\x18\n\x11HANDSHAKE_REFUSED\x10\xa8\x9a\xea\x0e\x12\x1a\n\x13UNSUPPORTED_VERSION\x10\xa9\x9a\xea\x0e\x12\x17\n\x10UNSUPPORTED_ALGO\x10\xaa\x9a\xea\x0e\x12\x19\n\x12UNSUPPORTED_PARAMS\x10\xab\x9a\xea\x0e\x62\x06proto3') - -_ERRORCODE = DESCRIPTOR.enum_types_by_name['ErrorCode'] -ErrorCode = enum_type_wrapper.EnumTypeWrapper(_ERRORCODE) -OK = 0 -GENERIC_ERROR = 31100000 -UNEXPECTED_ERROR = 31100001 -NETWORK_ERROR = 31100002 -INVALID_REQUEST = 31100100 -INVALID_RESOURCE = 31100101 -HANDSHAKE_REFUSED = 31100200 -UNSUPPORTED_VERSION = 31100201 -UNSUPPORTED_ALGO = 31100202 -UNSUPPORTED_PARAMS = 31100203 - - -_RESPONSEHEADER = DESCRIPTOR.message_types_by_name['ResponseHeader'] -ResponseHeader = _reflection.GeneratedProtocolMessageType('ResponseHeader', (_message.Message,), { - 'DESCRIPTOR' : _RESPONSEHEADER, - '__module__' : 'secretflow.ic.proto.common.header_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.ResponseHeader) - }) -_sym_db.RegisterMessage(ResponseHeader) - -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _ERRORCODE._serialized_start=122 - _ERRORCODE._serialized_end=365 - _RESPONSEHEADER._serialized_start=64 - _RESPONSEHEADER._serialized_end=119 -# @@protoc_insertion_point(module_scope) diff --git a/secretflow/ic/proto/handshake/__init__.py b/secretflow/ic/proto/handshake/__init__.py deleted file mode 100644 index 724997cbd..000000000 --- a/secretflow/ic/proto/handshake/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/secretflow/ic/proto/handshake/algos/sgb.proto b/secretflow/ic/proto/handshake/algos/sgb.proto deleted file mode 100644 index 0d14f41e4..000000000 --- a/secretflow/ic/proto/handshake/algos/sgb.proto +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package org.interconnection.v2.algos; - -//===================================// -// Protos used in HandshakeRequest // -//===================================// - -message SgbParamsProposal { - repeated int32 supported_versions = 1; - - // 训练第一棵树时是否仅采用主动参与方的样本列 - // 参见: https://arxiv.org/abs/1901.08755 Completely SecureBoost - bool support_completely_sgb = 100; - - // 是否启用行采样 - bool support_row_sample_by_tree = 101; - - // 是否启用列采样 - bool support_col_sample_by_tree = 102; -} - -//===================================// -// Protos used in HandshakeResponse // -//===================================// - -message SgbParamsResult { - // 版本号 - int32 version = 1; - - // 迭代次数 - int32 num_round = 2; - - // 树的最大深度 - int32 max_depth = 3; - - // 树训练的行采样率 - double row_sample_by_tree = 4; - - // 树训练的列采样率 - double col_sample_by_tree = 5; - - // 样本分桶的eps参数 - double bucket_eps = 6; - - bool use_completely_sgb = 100; -} diff --git a/secretflow/ic/proto/handshake/algos/sgb_pb2.py b/secretflow/ic/proto/handshake/algos/sgb_pb2.py deleted file mode 100644 index 41ce55474..000000000 --- a/secretflow/ic/proto/handshake/algos/sgb_pb2.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: secretflow/ic/proto/handshake/algos/sgb.proto -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n-secretflow/ic/proto/handshake/algos/sgb.proto\x12\x1corg.interconnection.v2.algos\"\x97\x01\n\x11SgbParamsProposal\x12\x1a\n\x12supported_versions\x18\x01 \x03(\x05\x12\x1e\n\x16support_completely_sgb\x18\x64 \x01(\x08\x12\"\n\x1asupport_row_sample_by_tree\x18\x65 \x01(\x08\x12\"\n\x1asupport_col_sample_by_tree\x18\x66 \x01(\x08\"\xb0\x01\n\x0fSgbParamsResult\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x11\n\tnum_round\x18\x02 \x01(\x05\x12\x11\n\tmax_depth\x18\x03 \x01(\x05\x12\x1a\n\x12row_sample_by_tree\x18\x04 \x01(\x01\x12\x1a\n\x12\x63ol_sample_by_tree\x18\x05 \x01(\x01\x12\x12\n\nbucket_eps\x18\x06 \x01(\x01\x12\x1a\n\x12use_completely_sgb\x18\x64 \x01(\x08\x62\x06proto3') - - - -_SGBPARAMSPROPOSAL = DESCRIPTOR.message_types_by_name['SgbParamsProposal'] -_SGBPARAMSRESULT = DESCRIPTOR.message_types_by_name['SgbParamsResult'] -SgbParamsProposal = _reflection.GeneratedProtocolMessageType('SgbParamsProposal', (_message.Message,), { - 'DESCRIPTOR' : _SGBPARAMSPROPOSAL, - '__module__' : 'secretflow.ic.proto.handshake.algos.sgb_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.algos.SgbParamsProposal) - }) -_sym_db.RegisterMessage(SgbParamsProposal) - -SgbParamsResult = _reflection.GeneratedProtocolMessageType('SgbParamsResult', (_message.Message,), { - 'DESCRIPTOR' : _SGBPARAMSRESULT, - '__module__' : 'secretflow.ic.proto.handshake.algos.sgb_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.algos.SgbParamsResult) - }) -_sym_db.RegisterMessage(SgbParamsResult) - -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _SGBPARAMSPROPOSAL._serialized_start=80 - _SGBPARAMSPROPOSAL._serialized_end=231 - _SGBPARAMSRESULT._serialized_start=234 - _SGBPARAMSRESULT._serialized_end=410 -# @@protoc_insertion_point(module_scope) diff --git a/secretflow/ic/proto/handshake/entry.proto b/secretflow/ic/proto/handshake/entry.proto deleted file mode 100644 index 8213854a3..000000000 --- a/secretflow/ic/proto/handshake/entry.proto +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -import "google/protobuf/any.proto"; -import "secretflow/ic/proto/common/header.proto"; - -package org.interconnection.v2; - -message HandshakeVersionCheckHelper { - int32 version = 1; -} - -// [Sphinx doc begin anchor: HandshakeRequest] -// unified protocol for interconnection -message HandshakeRequest { - // 握手请求版本号, 当前等于 2 - int32 version = 1; - - //** META INFO **// - - // The sender's rank - int32 requester_rank = 2; - - //** AI/BI 算法层 **// - - // enum AlgoType - repeated int32 supported_algos = 3; - - // 算法详细握手参数,与 supported_algos 一一对应 - // SS-LR:learning_rate,optimizer,normalize - // ECDH-PSI:Nothing,skip - repeated google.protobuf.Any algo_params = 4; - - //** 安全算子层 **// - - // AI/BI 算法所需的 op 列到此处 - // op = enum OpType - // ECDH-PSI:Nothing,skip - repeated int32 ops = 5; - repeated google.protobuf.Any op_params = 6; - - //** 密码协议层 **// - - // protocol_family = enum ProtocolFamily - // SS: Protocol: [Semi2K, ABY3], FieldType, BeaverConfig, SerializeFormat - // ECC: Hash2Curve, EcGroup, SerializeFormat - // PHE: Protocol: [Paillier, EcElgamal], SerializeFormat - repeated int32 protocol_families = 7; - repeated google.protobuf.Any protocol_family_params = 8; - - //** 数据 IO **// - - // 定义 AI/BI 算法的输入和结果输出格式,不包括中间交互数据的格式 - // PSI: item_count、result_to_rank - // SS-LR: sample_size、feature_num、has_label, etc. - google.protobuf.Any io_param = 9; -} -// [Sphinx doc end anchor: HandshakeRequest] - -// [Sphinx doc begin anchor: HandshakeResponse] -message HandshakeResponse { - // response header - ResponseHeader header = 1; - - //** AI/BI 算法层 **// - - // algos = enum AlgoType - int32 algo = 2; - - // 算法详细握手参数 - // SS-LR:learning_rate,optimizer,normalize - // ECDH-PSI:Nothing,skip - google.protobuf.Any algo_param = 3; - - //** 安全算子层 **// - - // AI/BI 算法所需的 op 列到此处 - // op = enum OpType - // ECDH-PSI:Nothing,skip - repeated int32 ops = 4; - repeated google.protobuf.Any op_params = 5; - - //** 密码协议层 **// - - // protocol_family = enum ProtocolFamily - // SS: Protocol: [Semi2K, ABY3], FieldType, BeaverConfig, SerializeFormat - // ECC: Hash2Curve, EcGroup, SerializeFormat - // PHE: Protocol: [Paillier, EcElgamal], SerializeFormat - repeated int32 protocol_families = 6; - repeated google.protobuf.Any protocol_family_params = 7; - - //** 数据 IO **// - - // 定义 AI/BI 算法的输入和结果输出格式,不包括中间交互数据的格式 - // PSI: item_count、result_to_rank - // SS-LR: sample_size、feature_num、has_label, etc. - google.protobuf.Any io_param = 8; -} -// [Sphinx doc end anchor: HandshakeResponse] - -// [Sphinx doc begin anchor: AlgoType] -enum AlgoType { - ALGO_TYPE_UNSPECIFIED = 0; - ALGO_TYPE_ECDH_PSI = 1; - ALGO_TYPE_SS_LR = 2; - ALGO_TYPE_SGB = 3; -} -// [Sphinx doc end anchor: AlgoType] - -// [Sphinx doc begin anchor: OpType] -enum OpType { - OP_TYPE_UNSPECIFIED = 0; - OP_TYPE_SIGMOID = 1; -} -// [Sphinx doc end anchor: OpType] - -// [Sphinx doc begin anchor: ProtocolFamily] -enum ProtocolFamily { - PROTOCOL_FAMILY_UNSPECIFIED = 0; - PROTOCOL_FAMILY_ECC = 1; - PROTOCOL_FAMILY_SS = 2; - PROTOCOL_FAMILY_PHE = 3; -} -// [Sphinx doc end anchor: ProtocolFamily] diff --git a/secretflow/ic/proto/handshake/entry_pb2.py b/secretflow/ic/proto/handshake/entry_pb2.py deleted file mode 100644 index 0fd2c8db5..000000000 --- a/secretflow/ic/proto/handshake/entry_pb2.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: secretflow/ic/proto/handshake/entry.proto -"""Generated protocol buffer code.""" -from google.protobuf.internal import enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -from secretflow.ic.proto.common import header_pb2 as secretflow_dot_ic_dot_proto_dot_common_dot_header__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)secretflow/ic/proto/handshake/entry.proto\x12\x16org.interconnection.v2\x1a\x19google/protobuf/any.proto\x1a\'secretflow/ic/proto/common/header.proto\".\n\x1bHandshakeVersionCheckHelper\x12\x0f\n\x07version\x18\x01 \x01(\x05\"\xae\x02\n\x10HandshakeRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x16\n\x0erequester_rank\x18\x02 \x01(\x05\x12\x17\n\x0fsupported_algos\x18\x03 \x03(\x05\x12)\n\x0b\x61lgo_params\x18\x04 \x03(\x0b\x32\x14.google.protobuf.Any\x12\x0b\n\x03ops\x18\x05 \x03(\x05\x12\'\n\top_params\x18\x06 \x03(\x0b\x32\x14.google.protobuf.Any\x12\x19\n\x11protocol_families\x18\x07 \x03(\x05\x12\x34\n\x16protocol_family_params\x18\x08 \x03(\x0b\x32\x14.google.protobuf.Any\x12&\n\x08io_param\x18\t \x01(\x0b\x32\x14.google.protobuf.Any\"\xaf\x02\n\x11HandshakeResponse\x12\x33\n\x06header\x18\x01 \x01(\x0b\x32#.org.interconnection.ResponseHeader\x12\x0c\n\x04\x61lgo\x18\x02 \x01(\x05\x12(\n\nalgo_param\x18\x03 \x01(\x0b\x32\x14.google.protobuf.Any\x12\x0b\n\x03ops\x18\x04 \x03(\x05\x12\'\n\top_params\x18\x05 \x03(\x0b\x32\x14.google.protobuf.Any\x12\x19\n\x11protocol_families\x18\x06 \x03(\x05\x12\x34\n\x16protocol_family_params\x18\x07 \x03(\x0b\x32\x14.google.protobuf.Any\x12&\n\x08io_param\x18\x08 \x01(\x0b\x32\x14.google.protobuf.Any*e\n\x08\x41lgoType\x12\x19\n\x15\x41LGO_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12\x41LGO_TYPE_ECDH_PSI\x10\x01\x12\x13\n\x0f\x41LGO_TYPE_SS_LR\x10\x02\x12\x11\n\rALGO_TYPE_SGB\x10\x03*6\n\x06OpType\x12\x17\n\x13OP_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fOP_TYPE_SIGMOID\x10\x01*{\n\x0eProtocolFamily\x12\x1f\n\x1bPROTOCOL_FAMILY_UNSPECIFIED\x10\x00\x12\x17\n\x13PROTOCOL_FAMILY_ECC\x10\x01\x12\x16\n\x12PROTOCOL_FAMILY_SS\x10\x02\x12\x17\n\x13PROTOCOL_FAMILY_PHE\x10\x03\x62\x06proto3') - -_ALGOTYPE = DESCRIPTOR.enum_types_by_name['AlgoType'] -AlgoType = enum_type_wrapper.EnumTypeWrapper(_ALGOTYPE) -_OPTYPE = DESCRIPTOR.enum_types_by_name['OpType'] -OpType = enum_type_wrapper.EnumTypeWrapper(_OPTYPE) -_PROTOCOLFAMILY = DESCRIPTOR.enum_types_by_name['ProtocolFamily'] -ProtocolFamily = enum_type_wrapper.EnumTypeWrapper(_PROTOCOLFAMILY) -ALGO_TYPE_UNSPECIFIED = 0 -ALGO_TYPE_ECDH_PSI = 1 -ALGO_TYPE_SS_LR = 2 -ALGO_TYPE_SGB = 3 -OP_TYPE_UNSPECIFIED = 0 -OP_TYPE_SIGMOID = 1 -PROTOCOL_FAMILY_UNSPECIFIED = 0 -PROTOCOL_FAMILY_ECC = 1 -PROTOCOL_FAMILY_SS = 2 -PROTOCOL_FAMILY_PHE = 3 - - -_HANDSHAKEVERSIONCHECKHELPER = DESCRIPTOR.message_types_by_name['HandshakeVersionCheckHelper'] -_HANDSHAKEREQUEST = DESCRIPTOR.message_types_by_name['HandshakeRequest'] -_HANDSHAKERESPONSE = DESCRIPTOR.message_types_by_name['HandshakeResponse'] -HandshakeVersionCheckHelper = _reflection.GeneratedProtocolMessageType('HandshakeVersionCheckHelper', (_message.Message,), { - 'DESCRIPTOR' : _HANDSHAKEVERSIONCHECKHELPER, - '__module__' : 'secretflow.ic.proto.handshake.entry_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.HandshakeVersionCheckHelper) - }) -_sym_db.RegisterMessage(HandshakeVersionCheckHelper) - -HandshakeRequest = _reflection.GeneratedProtocolMessageType('HandshakeRequest', (_message.Message,), { - 'DESCRIPTOR' : _HANDSHAKEREQUEST, - '__module__' : 'secretflow.ic.proto.handshake.entry_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.HandshakeRequest) - }) -_sym_db.RegisterMessage(HandshakeRequest) - -HandshakeResponse = _reflection.GeneratedProtocolMessageType('HandshakeResponse', (_message.Message,), { - 'DESCRIPTOR' : _HANDSHAKERESPONSE, - '__module__' : 'secretflow.ic.proto.handshake.entry_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.HandshakeResponse) - }) -_sym_db.RegisterMessage(HandshakeResponse) - -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _ALGOTYPE._serialized_start=796 - _ALGOTYPE._serialized_end=897 - _OPTYPE._serialized_start=899 - _OPTYPE._serialized_end=953 - _PROTOCOLFAMILY._serialized_start=955 - _PROTOCOLFAMILY._serialized_end=1078 - _HANDSHAKEVERSIONCHECKHELPER._serialized_start=137 - _HANDSHAKEVERSIONCHECKHELPER._serialized_end=183 - _HANDSHAKEREQUEST._serialized_start=186 - _HANDSHAKEREQUEST._serialized_end=488 - _HANDSHAKERESPONSE._serialized_start=491 - _HANDSHAKERESPONSE._serialized_end=794 -# @@protoc_insertion_point(module_scope) diff --git a/secretflow/ic/proto/handshake/protocol_family/phe.proto b/secretflow/ic/proto/handshake/protocol_family/phe.proto deleted file mode 100644 index f6178c9ab..000000000 --- a/secretflow/ic/proto/handshake/protocol_family/phe.proto +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -import "google/protobuf/any.proto"; - -package org.interconnection.v2.protocol; - -//===================================// -// Protos used in HandshakeRequest // -//===================================// - -message PheProtocolProposal { - repeated int32 supported_versions = 1; - repeated int32 supported_phe_algos = 2; // enum PheAlgo - repeated google.protobuf.Any supported_phe_params = 3; -} - -enum PheAlgo { - PHE_ALGO_UNSPECIFIED = 0; - PHE_ALGO_PAILLIER = 1; - PHE_ALGO_OU = 2; - PHE_ALGO_EC_ELGAMAL = 3; -} - -message PaillierParamsProposal { - // common key sizes are 1024/2048/3072 - // 1024 only used for debug. 1024 比特仅用于联调,禁止用于生产环境 - repeated int32 key_sizes = 1; -} - -//===================================// -// Protos used in HandshakeResponse // -//===================================// - -message PheProtocolResult { - int32 version = 1; - int32 phe_algo = 2; // enum PheAlgo - google.protobuf.Any phe_param = 3; -} - -message PaillierParamsResult { - int32 key_size = 1; -} diff --git a/secretflow/ic/proto/handshake/protocol_family/phe_pb2.py b/secretflow/ic/proto/handshake/protocol_family/phe_pb2.py deleted file mode 100644 index 3ac3833d8..000000000 --- a/secretflow/ic/proto/handshake/protocol_family/phe_pb2.py +++ /dev/null @@ -1,74 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: secretflow/ic/proto/handshake/protocol_family/phe.proto -"""Generated protocol buffer code.""" -from google.protobuf.internal import enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n7secretflow/ic/proto/handshake/protocol_family/phe.proto\x12\x1forg.interconnection.v2.protocol\x1a\x19google/protobuf/any.proto\"\x82\x01\n\x13PheProtocolProposal\x12\x1a\n\x12supported_versions\x18\x01 \x03(\x05\x12\x1b\n\x13supported_phe_algos\x18\x02 \x03(\x05\x12\x32\n\x14supported_phe_params\x18\x03 \x03(\x0b\x32\x14.google.protobuf.Any\"+\n\x16PaillierParamsProposal\x12\x11\n\tkey_sizes\x18\x01 \x03(\x05\"_\n\x11PheProtocolResult\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x10\n\x08phe_algo\x18\x02 \x01(\x05\x12\'\n\tphe_param\x18\x03 \x01(\x0b\x32\x14.google.protobuf.Any\"(\n\x14PaillierParamsResult\x12\x10\n\x08key_size\x18\x01 \x01(\x05*d\n\x07PheAlgo\x12\x18\n\x14PHE_ALGO_UNSPECIFIED\x10\x00\x12\x15\n\x11PHE_ALGO_PAILLIER\x10\x01\x12\x0f\n\x0bPHE_ALGO_OU\x10\x02\x12\x17\n\x13PHE_ALGO_EC_ELGAMAL\x10\x03\x62\x06proto3') - -_PHEALGO = DESCRIPTOR.enum_types_by_name['PheAlgo'] -PheAlgo = enum_type_wrapper.EnumTypeWrapper(_PHEALGO) -PHE_ALGO_UNSPECIFIED = 0 -PHE_ALGO_PAILLIER = 1 -PHE_ALGO_OU = 2 -PHE_ALGO_EC_ELGAMAL = 3 - - -_PHEPROTOCOLPROPOSAL = DESCRIPTOR.message_types_by_name['PheProtocolProposal'] -_PAILLIERPARAMSPROPOSAL = DESCRIPTOR.message_types_by_name['PaillierParamsProposal'] -_PHEPROTOCOLRESULT = DESCRIPTOR.message_types_by_name['PheProtocolResult'] -_PAILLIERPARAMSRESULT = DESCRIPTOR.message_types_by_name['PaillierParamsResult'] -PheProtocolProposal = _reflection.GeneratedProtocolMessageType('PheProtocolProposal', (_message.Message,), { - 'DESCRIPTOR' : _PHEPROTOCOLPROPOSAL, - '__module__' : 'secretflow.ic.proto.handshake.protocol_family.phe_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.protocol.PheProtocolProposal) - }) -_sym_db.RegisterMessage(PheProtocolProposal) - -PaillierParamsProposal = _reflection.GeneratedProtocolMessageType('PaillierParamsProposal', (_message.Message,), { - 'DESCRIPTOR' : _PAILLIERPARAMSPROPOSAL, - '__module__' : 'secretflow.ic.proto.handshake.protocol_family.phe_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.protocol.PaillierParamsProposal) - }) -_sym_db.RegisterMessage(PaillierParamsProposal) - -PheProtocolResult = _reflection.GeneratedProtocolMessageType('PheProtocolResult', (_message.Message,), { - 'DESCRIPTOR' : _PHEPROTOCOLRESULT, - '__module__' : 'secretflow.ic.proto.handshake.protocol_family.phe_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.protocol.PheProtocolResult) - }) -_sym_db.RegisterMessage(PheProtocolResult) - -PaillierParamsResult = _reflection.GeneratedProtocolMessageType('PaillierParamsResult', (_message.Message,), { - 'DESCRIPTOR' : _PAILLIERPARAMSRESULT, - '__module__' : 'secretflow.ic.proto.handshake.protocol_family.phe_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.protocol.PaillierParamsResult) - }) -_sym_db.RegisterMessage(PaillierParamsResult) - -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _PHEALGO._serialized_start=436 - _PHEALGO._serialized_end=536 - _PHEPROTOCOLPROPOSAL._serialized_start=120 - _PHEPROTOCOLPROPOSAL._serialized_end=250 - _PAILLIERPARAMSPROPOSAL._serialized_start=252 - _PAILLIERPARAMSPROPOSAL._serialized_end=295 - _PHEPROTOCOLRESULT._serialized_start=297 - _PHEPROTOCOLRESULT._serialized_end=392 - _PAILLIERPARAMSRESULT._serialized_start=394 - _PAILLIERPARAMSRESULT._serialized_end=434 -# @@protoc_insertion_point(module_scope) diff --git a/secretflow/ic/proto/runtime/__init__.py b/secretflow/ic/proto/runtime/__init__.py deleted file mode 100644 index 724997cbd..000000000 --- a/secretflow/ic/proto/runtime/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/secretflow/ic/proto/runtime/data_exchange.proto b/secretflow/ic/proto/runtime/data_exchange.proto deleted file mode 100644 index fa49cc340..000000000 --- a/secretflow/ic/proto/runtime/data_exchange.proto +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package org.interconnection.v2.runtime; - -message DataExchangeProtocol { - // enum ScalarType - int32 scalar_type = 1; - - // if scalar_type is SCALAR_TYPE_OBJECT, please put real type name here - // otherwise this field is optional - string scalar_type_name = 2; - - oneof container { - // Store single scalar or single object - Scalar scalar = 5; - - // List, F means the size of each element is fixed and V means variant - FScalarList f_scalar_list = 6; - VScalarList v_scalar_list = 7; - - // Ndarray, F means the size of each element is fixed and V means variant - FNdArray f_ndarray = 8; - VNdArray v_ndarray = 9; - - // List of ndarray - // F means the size of each element is fixed and V means variant - FNdArrayList f_ndarray_list = 10; - VNdArrayList v_ndarray_list = 11; - } -} - -enum ScalarType { - SCALAR_TYPE_UNSPECIFIED = 0; - SCALAR_TYPE_BOOL = 1; - SCALAR_TYPE_INT8 = 2; - SCALAR_TYPE_UINT8 = 3; - SCALAR_TYPE_INT16 = 4; - SCALAR_TYPE_UINT16 = 5; - SCALAR_TYPE_INT32 = 6; - SCALAR_TYPE_UINT32 = 7; - SCALAR_TYPE_INT64 = 8; - SCALAR_TYPE_UINT64 = 9; - SCALAR_TYPE_INT128 = 10; - SCALAR_TYPE_UINT128 = 11; - - SCALAR_TYPE_FLOAT16 = 15; - SCALAR_TYPE_FLOAT32 = 16; - SCALAR_TYPE_FLOAT64 = 17; - - SCALAR_TYPE_OBJECT = 20; -} - -message Scalar { - bytes buf = 1; -} - -// Fixed-length scalar list -// the items in this list are all same size -message FScalarList { - // The size of each item is item_buf.len / item_count - int64 item_count = 1; - bytes item_buf = 2; -} - -// variant length scalar list -// variant length means that the length of each serialized element is not equal. -message VScalarList { - repeated bytes items = 1; -} - -message FNdArray { - // The size of each item is item_buf.len / total_item, where total_item count - // can get from shape - repeated int64 shape = 1; - bytes item_buf = 2; -} - -message VNdArray { - repeated int64 shape = 1; - repeated bytes items = 2; -} - -// List of ndarray -message FNdArrayList { - repeated FNdArray ndarrays = 1; -} - -message VNdArrayList { - repeated VNdArray ndarrays = 1; -} diff --git a/secretflow/ic/proto/runtime/data_exchange_pb2.py b/secretflow/ic/proto/runtime/data_exchange_pb2.py deleted file mode 100644 index e2d41776e..000000000 --- a/secretflow/ic/proto/runtime/data_exchange_pb2.py +++ /dev/null @@ -1,125 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: secretflow/ic/proto/runtime/data_exchange.proto -"""Generated protocol buffer code.""" -from google.protobuf.internal import enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n/secretflow/ic/proto/runtime/data_exchange.proto\x12\x1eorg.interconnection.v2.runtime\"\xa6\x04\n\x14\x44\x61taExchangeProtocol\x12\x13\n\x0bscalar_type\x18\x01 \x01(\x05\x12\x18\n\x10scalar_type_name\x18\x02 \x01(\t\x12\x38\n\x06scalar\x18\x05 \x01(\x0b\x32&.org.interconnection.v2.runtime.ScalarH\x00\x12\x44\n\rf_scalar_list\x18\x06 \x01(\x0b\x32+.org.interconnection.v2.runtime.FScalarListH\x00\x12\x44\n\rv_scalar_list\x18\x07 \x01(\x0b\x32+.org.interconnection.v2.runtime.VScalarListH\x00\x12=\n\tf_ndarray\x18\x08 \x01(\x0b\x32(.org.interconnection.v2.runtime.FNdArrayH\x00\x12=\n\tv_ndarray\x18\t \x01(\x0b\x32(.org.interconnection.v2.runtime.VNdArrayH\x00\x12\x46\n\x0e\x66_ndarray_list\x18\n \x01(\x0b\x32,.org.interconnection.v2.runtime.FNdArrayListH\x00\x12\x46\n\x0ev_ndarray_list\x18\x0b \x01(\x0b\x32,.org.interconnection.v2.runtime.VNdArrayListH\x00\x42\x0b\n\tcontainer\"\x15\n\x06Scalar\x12\x0b\n\x03\x62uf\x18\x01 \x01(\x0c\"3\n\x0b\x46ScalarList\x12\x12\n\nitem_count\x18\x01 \x01(\x03\x12\x10\n\x08item_buf\x18\x02 \x01(\x0c\"\x1c\n\x0bVScalarList\x12\r\n\x05items\x18\x01 \x03(\x0c\"+\n\x08\x46NdArray\x12\r\n\x05shape\x18\x01 \x03(\x03\x12\x10\n\x08item_buf\x18\x02 \x01(\x0c\"(\n\x08VNdArray\x12\r\n\x05shape\x18\x01 \x03(\x03\x12\r\n\x05items\x18\x02 \x03(\x0c\"J\n\x0c\x46NdArrayList\x12:\n\x08ndarrays\x18\x01 \x03(\x0b\x32(.org.interconnection.v2.runtime.FNdArray\"J\n\x0cVNdArrayList\x12:\n\x08ndarrays\x18\x01 \x03(\x0b\x32(.org.interconnection.v2.runtime.VNdArray*\x8d\x03\n\nScalarType\x12\x1b\n\x17SCALAR_TYPE_UNSPECIFIED\x10\x00\x12\x14\n\x10SCALAR_TYPE_BOOL\x10\x01\x12\x14\n\x10SCALAR_TYPE_INT8\x10\x02\x12\x15\n\x11SCALAR_TYPE_UINT8\x10\x03\x12\x15\n\x11SCALAR_TYPE_INT16\x10\x04\x12\x16\n\x12SCALAR_TYPE_UINT16\x10\x05\x12\x15\n\x11SCALAR_TYPE_INT32\x10\x06\x12\x16\n\x12SCALAR_TYPE_UINT32\x10\x07\x12\x15\n\x11SCALAR_TYPE_INT64\x10\x08\x12\x16\n\x12SCALAR_TYPE_UINT64\x10\t\x12\x16\n\x12SCALAR_TYPE_INT128\x10\n\x12\x17\n\x13SCALAR_TYPE_UINT128\x10\x0b\x12\x17\n\x13SCALAR_TYPE_FLOAT16\x10\x0f\x12\x17\n\x13SCALAR_TYPE_FLOAT32\x10\x10\x12\x17\n\x13SCALAR_TYPE_FLOAT64\x10\x11\x12\x16\n\x12SCALAR_TYPE_OBJECT\x10\x14\x62\x06proto3') - -_SCALARTYPE = DESCRIPTOR.enum_types_by_name['ScalarType'] -ScalarType = enum_type_wrapper.EnumTypeWrapper(_SCALARTYPE) -SCALAR_TYPE_UNSPECIFIED = 0 -SCALAR_TYPE_BOOL = 1 -SCALAR_TYPE_INT8 = 2 -SCALAR_TYPE_UINT8 = 3 -SCALAR_TYPE_INT16 = 4 -SCALAR_TYPE_UINT16 = 5 -SCALAR_TYPE_INT32 = 6 -SCALAR_TYPE_UINT32 = 7 -SCALAR_TYPE_INT64 = 8 -SCALAR_TYPE_UINT64 = 9 -SCALAR_TYPE_INT128 = 10 -SCALAR_TYPE_UINT128 = 11 -SCALAR_TYPE_FLOAT16 = 15 -SCALAR_TYPE_FLOAT32 = 16 -SCALAR_TYPE_FLOAT64 = 17 -SCALAR_TYPE_OBJECT = 20 - - -_DATAEXCHANGEPROTOCOL = DESCRIPTOR.message_types_by_name['DataExchangeProtocol'] -_SCALAR = DESCRIPTOR.message_types_by_name['Scalar'] -_FSCALARLIST = DESCRIPTOR.message_types_by_name['FScalarList'] -_VSCALARLIST = DESCRIPTOR.message_types_by_name['VScalarList'] -_FNDARRAY = DESCRIPTOR.message_types_by_name['FNdArray'] -_VNDARRAY = DESCRIPTOR.message_types_by_name['VNdArray'] -_FNDARRAYLIST = DESCRIPTOR.message_types_by_name['FNdArrayList'] -_VNDARRAYLIST = DESCRIPTOR.message_types_by_name['VNdArrayList'] -DataExchangeProtocol = _reflection.GeneratedProtocolMessageType('DataExchangeProtocol', (_message.Message,), { - 'DESCRIPTOR' : _DATAEXCHANGEPROTOCOL, - '__module__' : 'secretflow.ic.proto.runtime.data_exchange_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.runtime.DataExchangeProtocol) - }) -_sym_db.RegisterMessage(DataExchangeProtocol) - -Scalar = _reflection.GeneratedProtocolMessageType('Scalar', (_message.Message,), { - 'DESCRIPTOR' : _SCALAR, - '__module__' : 'secretflow.ic.proto.runtime.data_exchange_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.runtime.Scalar) - }) -_sym_db.RegisterMessage(Scalar) - -FScalarList = _reflection.GeneratedProtocolMessageType('FScalarList', (_message.Message,), { - 'DESCRIPTOR' : _FSCALARLIST, - '__module__' : 'secretflow.ic.proto.runtime.data_exchange_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.runtime.FScalarList) - }) -_sym_db.RegisterMessage(FScalarList) - -VScalarList = _reflection.GeneratedProtocolMessageType('VScalarList', (_message.Message,), { - 'DESCRIPTOR' : _VSCALARLIST, - '__module__' : 'secretflow.ic.proto.runtime.data_exchange_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.runtime.VScalarList) - }) -_sym_db.RegisterMessage(VScalarList) - -FNdArray = _reflection.GeneratedProtocolMessageType('FNdArray', (_message.Message,), { - 'DESCRIPTOR' : _FNDARRAY, - '__module__' : 'secretflow.ic.proto.runtime.data_exchange_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.runtime.FNdArray) - }) -_sym_db.RegisterMessage(FNdArray) - -VNdArray = _reflection.GeneratedProtocolMessageType('VNdArray', (_message.Message,), { - 'DESCRIPTOR' : _VNDARRAY, - '__module__' : 'secretflow.ic.proto.runtime.data_exchange_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.runtime.VNdArray) - }) -_sym_db.RegisterMessage(VNdArray) - -FNdArrayList = _reflection.GeneratedProtocolMessageType('FNdArrayList', (_message.Message,), { - 'DESCRIPTOR' : _FNDARRAYLIST, - '__module__' : 'secretflow.ic.proto.runtime.data_exchange_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.runtime.FNdArrayList) - }) -_sym_db.RegisterMessage(FNdArrayList) - -VNdArrayList = _reflection.GeneratedProtocolMessageType('VNdArrayList', (_message.Message,), { - 'DESCRIPTOR' : _VNDARRAYLIST, - '__module__' : 'secretflow.ic.proto.runtime.data_exchange_pb2' - # @@protoc_insertion_point(class_scope:org.interconnection.v2.runtime.VNdArrayList) - }) -_sym_db.RegisterMessage(VNdArrayList) - -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _SCALARTYPE._serialized_start=982 - _SCALARTYPE._serialized_end=1379 - _DATAEXCHANGEPROTOCOL._serialized_start=84 - _DATAEXCHANGEPROTOCOL._serialized_end=634 - _SCALAR._serialized_start=636 - _SCALAR._serialized_end=657 - _FSCALARLIST._serialized_start=659 - _FSCALARLIST._serialized_end=710 - _VSCALARLIST._serialized_start=712 - _VSCALARLIST._serialized_end=740 - _FNDARRAY._serialized_start=742 - _FNDARRAY._serialized_end=785 - _VNDARRAY._serialized_start=787 - _VNDARRAY._serialized_end=827 - _FNDARRAYLIST._serialized_start=829 - _FNDARRAYLIST._serialized_end=903 - _VNDARRAYLIST._serialized_start=905 - _VNDARRAYLIST._serialized_end=979 -# @@protoc_insertion_point(module_scope) diff --git a/secretflow/ic/proxy/serializer.py b/secretflow/ic/proxy/serializer.py index 78bc042bb..340024289 100644 --- a/secretflow/ic/proxy/serializer.py +++ b/secretflow/ic/proxy/serializer.py @@ -18,7 +18,7 @@ import numpy as np from heu import numpy as hnp, phe -from secretflow.ic.proto.runtime import data_exchange_pb2 as de +from interconnection.runtime import data_exchange_pb2 as de PublicKey = phe.PublicKey diff --git a/secretflow/ml/boost/core/data_preprocess.py b/secretflow/ml/boost/core/data_preprocess.py index 5c9322280..de978ce0f 100644 --- a/secretflow/ml/boost/core/data_preprocess.py +++ b/secretflow/ml/boost/core/data_preprocess.py @@ -20,7 +20,7 @@ from secretflow.data import FedNdarray, PartitionWay from secretflow.data.vertical import VDataFrame -from secretflow.device import PYUObject, wait +from secretflow.device import PYUObject, reveal, wait def prepare_dataset( @@ -49,20 +49,35 @@ def prepare_dataset( shape = ds.shape assert math.prod(shape), f"not support empty dataset, shape {shape}" - return ds, shape +def non_empty(x, worker): + return worker(lambda x: x.size > 0)(x) + + def validate( dataset, label ) -> Tuple[FedNdarray, Tuple[int, int], PYUObject, Tuple[int, int]]: x, x_shape = prepare_dataset(dataset) y, y_shape = prepare_dataset(label) assert len(x_shape) == 2, "only support 2D-array on dtrain" + data_check_task = [ worker(data_checks)(x_val, worker) for worker, x_val in x.partitions.items() ] + data_not_emmpty = reveal( + [non_empty(x_val, worker) for worker, x_val in x.partitions.items()] + ) + to_remove_devices = [] + for i, empty_device in enumerate(x.partitions.keys()): + if not data_not_emmpty[i]: + to_remove_devices.append(empty_device) + + for device in to_remove_devices: + x.partitions.pop(device) + assert len(y_shape) == 1 or y_shape[1] == 1, "label only support one label col" samples = y_shape[0] assert samples == x_shape[0], "dtrain & label are not aligned" @@ -71,6 +86,7 @@ def validate( y = list(y.partitions.values())[0] y = y.device(lambda y: y.reshape(-1, 1, order='F'))(y) y_shape = (samples, 1) + assert samples > 0, "cannot have empty samples" wait(data_check_task) return x, x_shape, y, y_shape diff --git a/secretflow/ml/boost/sgb_v/core/distributed_tree/distributed_tree.py b/secretflow/ml/boost/sgb_v/core/distributed_tree/distributed_tree.py index 8caeba807..a4221cd07 100644 --- a/secretflow/ml/boost/sgb_v/core/distributed_tree/distributed_tree.py +++ b/secretflow/ml/boost/sgb_v/core/distributed_tree/distributed_tree.py @@ -64,15 +64,14 @@ def predict(self, x: Dict[PYU, PYUObject]) -> PYUObject: Returns: PYUObject: _description_ """ - assert len(self.split_tree_dict) == len( - x - ), "data parition number should match split tree number" assert self.label_holder is not None, "label holder must exist" assert len(self.split_tree_dict) > 0, "number of split tree must be not empty" shape = None weight_selects = list() for pyu, split_tree in self.split_tree_dict.items(): + if pyu not in x: + continue s = pyu(lambda split_tree, x: split_tree.predict_leaf_select(x))( split_tree, x[pyu].data ) diff --git a/secretflow/ml/boost/sgb_v/factory/booster/global_ordermap_booster.py b/secretflow/ml/boost/sgb_v/factory/booster/global_ordermap_booster.py index 4048fc672..73cdce347 100644 --- a/secretflow/ml/boost/sgb_v/factory/booster/global_ordermap_booster.py +++ b/secretflow/ml/boost/sgb_v/factory/booster/global_ordermap_booster.py @@ -27,7 +27,7 @@ from ...model import SgbModel from ..components import DataPreprocessor, ModelBuilder, OrderMapManager, TreeTrainer -from ..components.component import Composite, Devices, print_params +from ..components.component import Composite, Devices, label_have_feature, print_params from ..sgb_actor import SGBActor @@ -129,13 +129,20 @@ def fit( else: x, x_shape, y, _ = self.components.preprocessor.validate(dataset, label) sample_num = x_shape[0] - # set devices devices = Devices(y.device, [*x.partitions.keys()], self.heu) + actors = [SGBActor(device=device) for device in devices.workers] + if not label_have_feature(devices): + logging.warning( + "label holder has no feature, setting first tree with label holder to be False." + ) + # disable train using label holder's device + self.set_params({"first_tree_with_label_holder_feature": False}) + # add label holder to actors + actors.append(SGBActor(device=devices.label_holder)) self.set_devices(devices) - # set actors - actors = [SGBActor(device=device) for device in devices.workers] + logging.debug("actors are created.") self.set_actors(actors) logging.debug("actors are set.") diff --git a/secretflow/ml/boost/sgb_v/factory/components/component.py b/secretflow/ml/boost/sgb_v/factory/components/component.py index 8bce9acef..0682179c0 100644 --- a/secretflow/ml/boost/sgb_v/factory/components/component.py +++ b/secretflow/ml/boost/sgb_v/factory/components/component.py @@ -28,6 +28,10 @@ class Devices: heu: HEU +def label_have_feature(devices: Devices): + return devices.label_holder in devices.workers + + class Component(abc.ABC): @abc.abstractmethod def show_params(self): diff --git a/secretflow/ml/boost/sgb_v/factory/components/loss_computer/loss_computer.py b/secretflow/ml/boost/sgb_v/factory/components/loss_computer/loss_computer.py index 6877e919e..40913310a 100644 --- a/secretflow/ml/boost/sgb_v/factory/components/loss_computer/loss_computer.py +++ b/secretflow/ml/boost/sgb_v/factory/components/loss_computer/loss_computer.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Tuple, Union, List +from typing import List, Tuple, Union import numpy as np diff --git a/secretflow/ml/boost/sgb_v/factory/components/loss_computer/loss_computer_actor.py b/secretflow/ml/boost/sgb_v/factory/components/loss_computer/loss_computer_actor.py index a3831f4f9..e8f715d1d 100644 --- a/secretflow/ml/boost/sgb_v/factory/components/loss_computer/loss_computer_actor.py +++ b/secretflow/ml/boost/sgb_v/factory/components/loss_computer/loss_computer_actor.py @@ -14,9 +14,9 @@ from typing import Tuple -import numpy as np import jax.numpy as jnp +import numpy as np from ....core.params import RegType from ....core.pure_numpy_ops.grad import ( diff --git a/secretflow/ml/boost/sgb_v/factory/components/order_map_manager/order_map_manager.py b/secretflow/ml/boost/sgb_v/factory/components/order_map_manager/order_map_manager.py index 59edd858e..bf1a9c2d5 100644 --- a/secretflow/ml/boost/sgb_v/factory/components/order_map_manager/order_map_manager.py +++ b/secretflow/ml/boost/sgb_v/factory/components/order_map_manager/order_map_manager.py @@ -47,6 +47,7 @@ def __init__(self) -> None: self.logging_params = LoggingParams() self.buckets = eps_inverse(self.params.sketch_eps) self.order_map_actors = [] + self.workers = [] def show_params(self): print_params(self.params) @@ -71,8 +72,12 @@ def get_params(self, params: dict): def set_devices(self, devices: Devices): self.workers = devices.workers - def set_actors(self, actors: SGBActor): - self.order_map_actors = actors + def set_actors(self, actors: List[SGBActor]): + assert len(self.workers) > 0, "workers must be set" + # worker actors only + self.order_map_actors = [ + actor for actor in actors if actor.device in self.workers + ] for i, actor in enumerate(self.order_map_actors): actor.register_class('OrderMapActor', OrderMapActor, i) diff --git a/secretflow/ml/boost/sgb_v/factory/components/sampler/sample_actor.py b/secretflow/ml/boost/sgb_v/factory/components/sampler/sample_actor.py index 843c383e2..a4debba0a 100644 --- a/secretflow/ml/boost/sgb_v/factory/components/sampler/sample_actor.py +++ b/secretflow/ml/boost/sgb_v/factory/components/sampler/sample_actor.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union, Tuple +import math +from typing import List, Tuple, Union import numpy as np -import math # handle order map building for one party @@ -28,9 +28,10 @@ def __init__(self, seed): def generate_one_partition_col_choices( self, colsample, feature_buckets: List[int] ) -> Tuple[Union[None, np.ndarray], int]: - if colsample < 1: + # if we only have one column left, do not sample + if colsample < 1 and sum(feature_buckets) > 1: feature_num = len(feature_buckets) - choices = math.ceil(feature_num * colsample) + choices = max([math.ceil(feature_num * colsample), 1]) col_choices = np.sort(self.rng.choice(feature_num, choices, replace=False)) buckets_count = 0 diff --git a/secretflow/ml/boost/sgb_v/factory/components/split_tree_builder/split_tree_builder.py b/secretflow/ml/boost/sgb_v/factory/components/split_tree_builder/split_tree_builder.py index 5596dbf0d..7b2def949 100644 --- a/secretflow/ml/boost/sgb_v/factory/components/split_tree_builder/split_tree_builder.py +++ b/secretflow/ml/boost/sgb_v/factory/components/split_tree_builder/split_tree_builder.py @@ -56,7 +56,11 @@ def set_devices(self, devices: Devices): self.label_holder = devices.label_holder def set_actors(self, actors: List[SGBActor]): - self.split_tree_builder_actors = actors + assert len(self.workers) > 0, "workers must be set" + # worker actors only + self.split_tree_builder_actors = [ + actor for actor in actors if actor.device in self.workers + ] for i, actor in enumerate(self.split_tree_builder_actors): actor.register_class('SplitTreeActor', SplitTreeActor, i) diff --git a/secretflow/ml/boost/sgb_v/model.py b/secretflow/ml/boost/sgb_v/model.py index 6ebec39d7..b6517723c 100644 --- a/secretflow/ml/boost/sgb_v/model.py +++ b/secretflow/ml/boost/sgb_v/model.py @@ -11,24 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json +import os from pathlib import Path -from typing import Dict, Union, List +from typing import Dict, List, Union import jax.numpy as jnp from secretflow.data import FedNdarray, PartitionWay from secretflow.data.vertical import VDataFrame from secretflow.device import PYU, PYUObject, reveal, wait +from secretflow.ml.boost.core.data_preprocess import prepare_dataset -from .core.distributed_tree.distributed_tree import DistributedTree -from .core.distributed_tree.distributed_tree import from_dict as dt_from_dict +from .core.distributed_tree.distributed_tree import ( + DistributedTree, + from_dict as dt_from_dict, +) from .core.params import RegType -from secretflow.ml.boost.core.data_preprocess import prepare_dataset from .core.pure_numpy_ops.pred import sigmoid - common_path_postfix = "/common.json" leaf_weight_postfix = "/leaf_weight.json" split_tree_postfix = "/split_tree.json" @@ -56,7 +57,7 @@ def _insert_distributed_tree(self, tree: DistributedTree): def predict( self, - dtrain: Union[FedNdarray, VDataFrame], + dtrain: Union[FedNdarray, VDataFrame, Dict[PYU, PYUObject]], to_pyu: PYU = None, ) -> Union[PYUObject, FedNdarray]: """ @@ -76,18 +77,20 @@ def predict( """ if len(self.trees) == 0: return None - x, _ = prepare_dataset(dtrain) - x = x.partitions - preds = [] + + pred = 0 + if isinstance(dtrain, dict): + x = dtrain + else: + x, _ = prepare_dataset(dtrain) + x = x.partitions + for tree in self.trees: - pred = tree.predict(x) - preds.append(pred) + pred = self.label_holder(lambda x, y: jnp.add(x, y))(tree.predict(x), pred) - pred = self.label_holder( - lambda ps, base: ( - jnp.sum(jnp.concatenate(ps, axis=1), axis=1) + base - ).reshape(-1, 1) - )(preds, self.base) + pred = self.label_holder(lambda x, y: jnp.add(x, y).reshape(-1, 1))( + pred, self.base + ) if self.objective == RegType.Logistic: pred = self.label_holder(sigmoid)(pred) @@ -162,9 +165,12 @@ def json_dump(obj, path): finish_split_trees = [] for device, path in device_path_dict.items(): split_tree_path = path + split_tree_postfix - finish_split_trees.append( - device(json_dump)(model_dict['split_trees'][device], split_tree_path) - ) + if device in model_dict['split_trees']: + finish_split_trees.append( + device(json_dump)( + model_dict['split_trees'][device], split_tree_path + ) + ) # no real content, handler for wait r = (finish_common, finish_leaf, finish_split_trees) @@ -198,6 +204,10 @@ def build_split_tree_dict(i): return sm +def check_file_exists(path): + return True if os.path.isfile(path) else False + + def from_json_to_dict( device_path_dict: Dict, label_holder: PYU, @@ -220,6 +230,12 @@ def json_load(path): leaf_weight_path ) ] + # check split trees + split_devices = {} + for device, path in device_path_dict.items(): + if reveal(device(check_file_exists)(path + split_tree_postfix)): + split_devices[device] = path + return { 'label_holder': label_holder, 'common': common_params, @@ -230,7 +246,7 @@ def json_load(path): path + split_tree_postfix ) ] - for device, path in device_path_dict.items() + for device, path in split_devices.items() }, } diff --git a/secretflow/ml/boost/ss_xgb_v/model.py b/secretflow/ml/boost/ss_xgb_v/model.py index 46009f4b2..7dc0d591d 100644 --- a/secretflow/ml/boost/ss_xgb_v/model.py +++ b/secretflow/ml/boost/ss_xgb_v/model.py @@ -25,17 +25,17 @@ from secretflow.data.vertical import VDataFrame from secretflow.device import ( PYU, - SPU, PYUObject, + SPU, SPUCompilerNumReturnsPolicy, SPUObject, wait, ) +from secretflow.ml.boost.core.data_preprocess import prepare_dataset, validate from .core import node_split as split_fn from .core.node_split import RegType from .core.tree_worker import XgbTreeWorker as Worker -from secretflow.ml.boost.core.data_preprocess import prepare_dataset, validate class XgbModel: @@ -94,17 +94,13 @@ def predict( ), f"{len(x.partitions)}, {self.trees[0]}" self.workers = [Worker(0, device=pyu) for pyu in x.partitions] self.x = x.partitions - preds = [] + pred = 0 for idx in range(len(self.trees)): - pred = self._tree_pred(self.trees[idx], self.weights[idx]) - wait([pred]) - preds.append(pred) - - pred = self.spu( - lambda ps, base: ( - jnp.sum(jnp.concatenate(ps, axis=0), axis=0) + base - ).reshape(-1, 1) - )(preds, self.base) + pred = self.spu(lambda x, y: jnp.add(x, y))( + self._tree_pred(self.trees[idx], self.weights[idx]), pred + ) + + pred = self.spu(lambda x, y: jnp.add(x, y).reshape(-1, 1))(pred, self.base) if self.objective == RegType.Logistic: pred = self.spu(split_fn.sigmoid)(pred) diff --git a/secretflow/ml/nn/applications/sl_resnet_torch.py b/secretflow/ml/nn/applications/sl_resnet_torch.py new file mode 100644 index 000000000..a1f99bedc --- /dev/null +++ b/secretflow/ml/nn/applications/sl_resnet_torch.py @@ -0,0 +1,262 @@ +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + + +def conv3x3( + in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 +) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNetBase(nn.Module): + def __init__( + self, + block: BasicBlock, + layers: List[int], + input_channels: int = 3, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + classifier: Optional[nn.Module] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.classifier = classifier + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + f"or a 3-element tuple, got {replace_stride_with_dilation}" + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + input_channels, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias=False, + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, BasicBlock) and m.bn2.weight is not None: + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer( + self, + block: BasicBlock, + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + if self.classifier is not None: + x = self.classifier(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + def output_num(self): + return 1 + + +class ResNetFuse(nn.Module): + def __init__( + self, + num_classes: int = 10, + dnn_units_size: List[int] = [512 * 2], + dnn_activation: Optional[str] = "relu", + use_dropout: bool = False, + dropout: float = 0.5, + classifier: Optional[nn.Module] = None, + ): + super(ResNetFuse, self).__init__() + if classifier is not None: + self.classifier = classifier + else: + layers = [] + for i in range(1, len(dnn_units_size)): + layers.append(nn.Linear(dnn_units_size[i - 1], dnn_units_size[i])) + if dnn_activation is not None and dnn_activation == 'relu': + layers.append(nn.ReLU(True)) + if use_dropout: + layers.append(nn.Dropout(p=dropout)) + layers.append(nn.Linear(dnn_units_size[-1], num_classes)) + self.classifier = nn.Sequential(*layers) + + def forward(self, inputs): + fuse_input = torch.cat(inputs, dim=1) + outputs = self.classifier(fuse_input) + return outputs + + +# just for exp +class NaiveSumSoftmax(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) # just to pass optimizer in SLModel + self.layer = nn.Softmax(dim=-1) + + def forward(self, x: List[Tensor]) -> Tensor: + x = x[0] + x[1] + out = self.layer(x) + return out diff --git a/secretflow/ml/nn/applications/sl_vgg_torch.py b/secretflow/ml/nn/applications/sl_vgg_torch.py new file mode 100644 index 000000000..301dc068b --- /dev/null +++ b/secretflow/ml/nn/applications/sl_vgg_torch.py @@ -0,0 +1,182 @@ +from typing import Dict, List, Optional, Union, cast + +import torch +import torch.nn as nn +from torch import Tensor + +from secretflow.ml.nn.utils import BaseModule + +cfgs: Dict[str, List[Union[str, int]]] = { + "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "D": [ + 64, + 64, + "M", + 128, + 128, + "M", + 256, + 256, + 256, + "M", + 512, + 512, + 512, + "M", + 512, + 512, + 512, + "M", + ], + "D_Mini": [ + 64, + 64, + "M", + 128, + 128, + "M", + 256, + 256, + 256, + "M", + 512, + 512, + 512, + "M", + 512, + 512, + 512, + ], + "E": [ + 64, + 64, + "M", + 128, + 128, + "M", + 256, + 256, + 256, + 256, + "M", + 512, + 512, + 512, + 512, + "M", + 512, + 512, + 512, + 512, + "M", + ], +} + + +def make_layers( + cfg: List[Union[str, int]], input_channels: int = 3, batch_norm: bool = False +) -> nn.Sequential: + layers: List[nn.Module] = [] + for v in cfg: + if v == "M": + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + v = cast(int, v) + conv2d = nn.Conv2d(input_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + input_channels = v + return nn.Sequential(*layers) + + +class VGGBase(BaseModule): + def __init__( + self, + input_channels: int = 3, + features: Optional[nn.Module] = None, + init_weights: bool = True, + classifier=None, + ) -> None: + super().__init__() + self.features = ( + features + if features is not None + else make_layers( + cfgs["D_Mini"], input_channels=input_channels, batch_norm=False + ) + ) + self.avgpool = nn.AdaptiveAvgPool2d((3, 3)) + self.classifier = classifier + + if init_weights: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode="fan_out", nonlinearity="relu" + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.features(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + if self.classifier is not None: + x = self.classifier(x) + return x + + def output_num(self): + return 1 + + +class VGGFuse(BaseModule): + def __init__( + self, + num_classes: int = 10, + dnn_units_size=[512 * 3 * 3 * 2, 4096, 4096], + dnn_activation="relu", + use_dropout=True, + dropout: float = 0.5, + classifier=None, + **kwargs + ): + super(VGGFuse, self).__init__() + if classifier is not None: + self.classifier = classifier + else: + layers = [] + for i in range(1, len(dnn_units_size)): + layers.append(nn.Linear(dnn_units_size[i - 1], dnn_units_size[i])) + if dnn_activation == 'relu': + layers.append(nn.ReLU(True)) + if use_dropout: + layers.append(nn.Dropout(p=dropout)) + layers.append(nn.Linear(dnn_units_size[-1], num_classes)) + self.classifier = nn.Sequential(*layers) + + def forward(self, inputs): + fuse_input = torch.cat(inputs, dim=1) + outputs = self.classifier(fuse_input) + return outputs + + +# just for attack exp +class NaiveSumSoftmax(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) # just to pass optimizer in SLModel + self.layer = nn.Softmax(dim=-1) + + def forward(self, x: List[Tensor]) -> Tensor: + x = x[0] + x[1] + out = self.layer(x) + return out diff --git a/secretflow/ml/nn/sl/backend/tensorflow/sl_base.py b/secretflow/ml/nn/sl/backend/tensorflow/sl_base.py index 634d23f5f..a02a4bcce 100644 --- a/secretflow/ml/nn/sl/backend/tensorflow/sl_base.py +++ b/secretflow/ml/nn/sl/backend/tensorflow/sl_base.py @@ -54,7 +54,7 @@ def __init__( self.eval_set = None self.valid_set = None self.tape = None - self.h = None + self._h = None self.train_x, self.train_y = None, None self.eval_x, self.eval_y = None, None self.kwargs = {} @@ -68,8 +68,13 @@ def __init__( self.eval_sample_weight = None self.fuse_callbacks = None self.predict_callbacks = None + self._data_x = None + self._gradient = None + self._training = True + self._pre_train_y = [] # record training status self.logs = None + self._pred_y = None self.steps_per_epoch = None self.skip_gradient = False if random_seed is not None: @@ -294,7 +299,7 @@ def build_dataset_from_builder( def build_dataset_from_csv( self, file_path: str, - label: str = None, + label: Optional[str] = None, s_w: Optional[np.ndarray] = None, batch_size=-1, shuffle=False, @@ -305,6 +310,10 @@ def build_dataset_from_csv( label_decoder=None, stage="train", ): + assert label is None or isinstance(label, str), ( + f"Got input label type {type(label)}, but default csv builder need str insead. " + f"Try use your own dataset builder" + ) data_set = tf.data.experimental.make_csv_dataset( file_path, batch_size=batch_size, @@ -378,59 +387,67 @@ def unpack_dataset(self, data, has_x, has_y, has_s_w): return data_x, data_y, data_s_w - def _reset_data_iter(self, stage): + def reset_data_iter(self, stage): if stage == "train": self.train_set = iter(self.train_dataset) elif stage == "eval": self.eval_set = iter(self.eval_dataset) + self._pre_train_y = [] - def base_forward(self, stage="train", step=0) -> ForwardData: - """compute hidden embedding - Args: - stage: Which stage of the base forward - Returns: hidden embedding - """ + def recv_gradient(self, gradient): + self._gradient = gradient - data_x = None + def get_batch_data(self, stage="train"): self.init_data() - training = True - if step == 0: - self._reset_data_iter(stage=stage) - + self._training = True if stage == "train": train_data = next(self.train_set) self.train_x, self.train_y, self.train_sample_weight = self.unpack_dataset( train_data, self.train_has_x, self.train_has_y, self.train_has_s_w ) - data_x = self.train_x + self._data_x = self.train_x + self._pre_train_y.append(self.train_y) elif stage == "eval": - training = False + self._training = False eval_data = next(self.eval_set) self.eval_x, self.eval_y, self.eval_sample_weight = self.unpack_dataset( eval_data, self.eval_has_x, self.eval_has_y, self.eval_has_s_w ) - data_x = self.eval_x + self._data_x = self.eval_x else: raise Exception("invalid stage") + def base_forward(self) -> ForwardData: + """compute hidden embedding + Args: + stage: Which stage of the base forward + Returns: hidden embedding + """ + # model_base is none equal to x is none if not self.model_base: return None # Strip tuple of length one, e.g: (x,) -> x - data_x = data_x[0] if isinstance(data_x, Tuple) and len(data_x) == 1 else data_x + self._data_x = ( + self._data_x[0] + if isinstance(self._data_x, Tuple) and len(self._data_x) == 1 + else self._data_x + ) self.tape = tf.GradientTape(persistent=True) with self.tape: - self.h = self._base_forward_internal(data_x, training=training) - self.data_x = data_x + self._h = self._base_forward_internal(self._data_x, training=self._training) + def pack_forward_data(self): + if not self.model_base: + return None forward_data = ForwardData() if len(self.model_base.losses) > 0: forward_data.losses = tf.add_n(self.model_base.losses) # The compressor can only recognize np type but not tensor. - forward_data.hidden = self.h.numpy() if tf.is_tensor(self.h) else self.h + forward_data.hidden = self._h.numpy() if tf.is_tensor(self._h) else self._h return forward_data def fuse_net( @@ -481,7 +498,7 @@ def fuse_net( return [None] * _num_returns return gradient - def base_backward(self, gradient): + def base_backward(self): """backward on fusenet Args: @@ -491,12 +508,12 @@ def base_backward(self, gradient): return_hiddens = [] with self.tape: - if len(gradient) == len(self.h): - for i in range(len(gradient)): - return_hiddens.append(self.fuse_op(self.h[i], gradient[i])) + if len(self._gradient) == len(self._h): + for i in range(len(self._gradient)): + return_hiddens.append(self.fuse_op(self._h[i], self._gradient[i])) else: - gradient = gradient[0] - return_hiddens.append(self.fuse_op(self.h, gradient)) + self._gradient = self._gradient[0] + return_hiddens.append(self.fuse_op(self._h, self._gradient)) # add model.losses into graph return_hiddens.append(self.model_base.losses) @@ -507,7 +524,7 @@ def base_backward(self, gradient): # clear intermediate results self.tape = None - self.h = None + self._h = None self.kwargs = {} def reset_metrics(self): @@ -670,6 +687,7 @@ def _fuse_net_internal(self, hiddens, losses, train_y, train_sample_weight): # Step 1: forward pass y_pred = self.model_fuse(hiddens, training=True, **self.kwargs) + self._pred_y = y_pred # Step 2: loss calculation, the loss function is configured in `compile()`. # if isinstance(self.model_fuse.loss, tfutils.custom_loss): # self.model_fuse.loss.with_kwargs(kwargs) diff --git a/secretflow/ml/nn/sl/backend/tensorflow/strategy/pipeline.py b/secretflow/ml/nn/sl/backend/tensorflow/strategy/pipeline.py index 1ab0f3f74..54de6a2e6 100644 --- a/secretflow/ml/nn/sl/backend/tensorflow/strategy/pipeline.py +++ b/secretflow/ml/nn/sl/backend/tensorflow/strategy/pipeline.py @@ -50,8 +50,22 @@ def __init__( self.trainable_vars = [] self.base_tape = [] self.fuse_tape = [] - self.h = [] - self.pre_train_y = [] + self._h = [] + self.hidden_list = [] + self._pre_train_y = [] + + def reset_data_iter(self, stage): + if stage == "train": + self.train_set = iter(self.train_dataset) + elif stage == "eval": + self.eval_set = iter(self.eval_dataset) + # reset some status + self.trainable_vars = [] + self.base_tape = [] + self.fuse_tape = [] + self._h = [] + self.hidden_list = [] + self._pre_train_y = [] def base_forward( self, stage="train", step=0, compress: bool = False @@ -66,91 +80,38 @@ def base_forward( self.model_base is not None ), "Base model cannot be none, please give model define or load a trained model" - data_x = None - training = True - self.init_data() - if step == 0: - self._reset_data_iter(stage=stage) - - if stage == "train": - train_data = next(self.train_set) - if self.train_has_y: - if self.train_has_s_w: - data_x = train_data[:-2] - train_y = train_data[-2] - self.train_sample_weight = train_data[-1] - else: - data_x = train_data[:-1] - train_y = train_data[-1] - # Label differential privacy - if self.label_dp is not None: - dp_train_y = self.label_dp(train_y.numpy()) - self.train_y = tf.convert_to_tensor(dp_train_y) - else: - self.train_y = train_y - self.pre_train_y.append(self.train_y) - else: - data_x = train_data - elif stage == "eval": - training = False - eval_data = next(self.eval_set) - if self.eval_has_y: - if self.eval_has_s_w: - data_x = eval_data[:-2] - eval_y = eval_data[-2] - self.eval_sample_weight = eval_data[-1] - else: - data_x = eval_data[:-1] - eval_y = eval_data[-1] - # Label differential privacy - if self.label_dp is not None: - dp_eval_y = self.label_dp(eval_y.numpy()) - self.eval_y = tf.convert_to_tensor(dp_eval_y) - else: - self.eval_y = eval_y - else: - data_x = eval_data - else: - raise Exception("invalid stage") - # Strip tuple of length one, e.g: (x,) -> x - data_x = data_x[0] if isinstance(data_x, Tuple) and len(data_x) == 1 else data_x + self._data_x = ( + self._data_x[0] + if isinstance(self._data_x, Tuple) and len(self._data_x) == 1 + else self._data_x + ) tape = tf.GradientTape(persistent=True) if stage == "train": self.base_tape.append(tape) self.trainable_vars.append(self.model_base.trainable_variables) with tape: - h = self._base_forward_internal( - data_x, - training=training, + self._h = self._base_forward_internal( + self._data_x, + training=self._training, ) if stage == "train": - self.h.append(h) - - self.data_x = data_x + self.hidden_list.append(self._h) - forward_data = ForwardData() - if len(self.model_base.losses) > 0: - forward_data.losses = tf.add_n(self.model_base.losses) - # TODO: only vaild on no server mode, refactor when use agglayer or server mode. - # no need to compress data on model_fuse side - forward_data.hidden = h - return forward_data - - def base_backward(self, gradient): + def base_backward(self): """backward on fusenet Args: gradient: gradient of fusenet hidden layer """ return_hiddens = [] - + gradient = self._gradient # TODO: only vaild on no server mode, refactor when use agglayer or server mode. # no need to decompress data on model_fuse side pre_tape = self.base_tape.pop(0) with pre_tape: - h = self.h.pop(0) + h = self.hidden_list.pop(0) if len(gradient) == len(h): for i in range(len(gradient)): return_hiddens.append(self.fuse_op(h[i], gradient[i])) @@ -173,7 +134,7 @@ def base_backward(self, gradient): self.kwargs = {} def _fuse_net_train(self, hiddens, losses=[]): - train_y = self.pre_train_y.pop(0) + train_y = self._pre_train_y.pop(0) return self._fuse_net_internal( hiddens, losses, @@ -181,19 +142,6 @@ def _fuse_net_train(self, hiddens, losses=[]): self.train_sample_weight, ) - def on_epoch_end(self, epoch): - # clean pipeline - self.trainable_vars = [] - self.base_tape = [] - self.fuse_tape = [] - self.h = [] - self.pre_train_y = [] - - if self.fuse_callbacks: - self.fuse_callbacks.on_epoch_end(epoch, self.epoch_logs) - self.training_logs = self.epoch_logs - return self.epoch_logs - @register_strategy(strategy_name='pipeline', backend='tensorflow') @proxy(PYUObject) diff --git a/secretflow/ml/nn/sl/backend/tensorflow/strategy/split_async.py b/secretflow/ml/nn/sl/backend/tensorflow/strategy/split_async.py index 8033f63f1..51d7d1d3d 100644 --- a/secretflow/ml/nn/sl/backend/tensorflow/strategy/split_async.py +++ b/secretflow/ml/nn/sl/backend/tensorflow/strategy/split_async.py @@ -66,21 +66,21 @@ def _base_forward_internal(self, data_x, use_dp: bool = True, training=True): h = self.embedding_dp(h) return h - def base_backward(self, gradient): + def base_backward(self): """backward on fusenet Args: gradient: gradient of fusenet hidden layer """ - + gradient = self._gradient for local_step in range(self.base_local_steps): return_hiddens = [] with self.tape: - if local_step == 0 and self.h is not None: - h = self.h + if local_step == 0 and self._h is not None: + h = self._h else: h = self._base_forward_internal( - self.data_x, + self._data_x, use_dp=False, training=True, # backward will only in training procedure ) @@ -99,12 +99,12 @@ def base_backward(self, gradient): # clear intermediate results self.tape = None - self.h = None - self.data_x = None + self._h = None + self._data_x = None self.kwargs = {} def _fuse_net_train(self, hiddens, losses=[]): - self.hiddens = copy.deepcopy(hiddens) + self._hiddens = copy.deepcopy(hiddens) return self._fuse_net_async_internal( hiddens, losses, @@ -164,7 +164,7 @@ def _fuse_net_async_internal( hidden_layer_gradients = [ grad + bound_param * (layer_var - h) for grad, layer_var, h in zip( - hidden_layer_gradients, hiddens, self.hiddens + hidden_layer_gradients, hiddens, self._hiddens ) ] hiddens = [ diff --git a/secretflow/ml/nn/sl/backend/tensorflow/strategy/split_state_async.py b/secretflow/ml/nn/sl/backend/tensorflow/strategy/split_state_async.py index b65766c9e..e64904ef4 100644 --- a/secretflow/ml/nn/sl/backend/tensorflow/strategy/split_state_async.py +++ b/secretflow/ml/nn/sl/backend/tensorflow/strategy/split_state_async.py @@ -90,6 +90,7 @@ def _fuse_net_internal(self, hiddens, losses, train_y, train_sample_weight): # Step 1: forward pass y_pred = self.model_fuse(hiddens, training=True, **self.kwargs) + # Step 2: loss calculation, the loss function is configured in `compile()`. loss = self.model_fuse.compiled_loss( train_y, @@ -97,7 +98,7 @@ def _fuse_net_internal(self, hiddens, losses, train_y, train_sample_weight): sample_weight=train_sample_weight, regularization_losses=self.model_fuse.losses + losses, ) - + self._pred_y = y_pred # Step3: compute gradients trainable_vars = self.model_fuse.trainable_variables gradients = tape.gradient(loss, trainable_vars) diff --git a/secretflow/ml/nn/sl/backend/torch/sl_base.py b/secretflow/ml/nn/sl/backend/torch/sl_base.py index 0cb6f7d79..54697dd7f 100644 --- a/secretflow/ml/nn/sl/backend/torch/sl_base.py +++ b/secretflow/ml/nn/sl/backend/torch/sl_base.py @@ -67,7 +67,7 @@ def __init__( self.eval_iter = None self.valid_iter = None self.tape = None - self.h = None + self._h = None self.train_x, self.train_y = None, None self.eval_x, self.eval_y = None, None self.kwargs = {} @@ -80,6 +80,9 @@ def __init__( self.train_sample_weight = None self.eval_sample_weight = None self.fuse_callbacks = None + self._data_x = None # get_batch_data output + self._gradient = None + self._pred_y = None # record all logs of training on workers self.logs = None self.steps_per_epoch = None @@ -152,6 +155,21 @@ def get_basenet_output_num(self): else: return 0 + def recv_gradient(self, gradient): + self._gradient = gradient + + def pack_forward_data(self): + if not self.model_base: + return None + forward_data = ForwardData() + + # The compressor can only recognize np type but not tensor. + forward_data.hidden = ( + self._h.detach().numpy() if isinstance(self._h, torch.Tensor) else self._h + ) + # The compressor in forward can only recognize np type but not tensor. + return forward_data + def unpack_dataset(self, data, has_x, has_y, has_s_w): data_x, data_y, data_s_w = None, None, None # case: only has x or has y, and s_w is none @@ -180,7 +198,6 @@ def unpack_dataset(self, data, has_x, has_y, has_s_w): return data_x, data_y, data_s_w def get_batch_data(self, stage="train"): - data_x = None self.init_data() # init model stat to train @@ -196,7 +213,7 @@ def get_batch_data(self, stage="train"): ) = self.unpack_dataset( train_data, self.train_has_x, self.train_has_y, self.train_has_s_w ) - data_x = self.train_x + self._data_x = self.train_x elif stage == "eval": if self.model_base: @@ -206,18 +223,18 @@ def get_batch_data(self, stage="train"): self.eval_x, self.eval_y, self.eval_sample_weight = self.unpack_dataset( eval_data, self.eval_has_x, self.eval_has_y, self.eval_has_s_w ) - data_x = self.eval_x + self._data_x = self.eval_x else: raise Exception("invalid stage") # Strip tuple of length one, e.g: (x,) -> x - data_x = ( - data_x[0] - if isinstance(data_x, (Tuple, List)) and len(data_x) == 1 - else data_x + self._data_x = ( + self._data_x[0] + if isinstance(self._data_x, (Tuple, List)) and len(self._data_x) == 1 + else self._data_x ) - return data_x + return self._data_x def build_dataset_from_numeric( self, @@ -459,7 +476,7 @@ def fuse_net_internal(self, hiddens, train_y, train_sample_weight, logs): hiddens = hiddens.requires_grad_() y_pred = self.model_fuse(hiddens, **self.kwargs) - + self._pred_y = y_pred # Step 2: loss calculation. # NOTE: Refer to https://stackoverflow.com/questions/67730325/using-weights-in-crossentropyloss-and-bceloss-pytorch to use sample weight # custom loss will be re-open in the next version @@ -780,7 +797,7 @@ def get_fuse_weights(self): def get_stop_training(self): return False # currently not supported - def _reset_data_iter(self, stage): + def reset_data_iter(self, stage): if self.shuffle and self.random_seed: # FIXME: need a better way to handle global random state torch.manual_seed(self.random_seed) diff --git a/secretflow/ml/nn/sl/backend/torch/strategy/split_nn.py b/secretflow/ml/nn/sl/backend/torch/strategy/split_nn.py index c04de133a..9339fe2dd 100644 --- a/secretflow/ml/nn/sl/backend/torch/strategy/split_nn.py +++ b/secretflow/ml/nn/sl/backend/torch/strategy/split_nn.py @@ -31,31 +31,20 @@ class SLTorchModel(SLBaseTorchModel): - def base_forward(self, stage="train", step=0, **kwargs) -> Optional[ForwardData]: + def base_forward(self) -> Optional[ForwardData]: """compute hidden embedding Args: stage: Which stage of the base forward Returns: hidden embedding """ - if step == 0: - self._reset_data_iter(stage=stage) - data_x = self.get_batch_data(stage=stage) if not self.model_base: return None - self.h = self.base_forward_internal( - data_x, + self._h = self.base_forward_internal( + self._data_x, ) - forward_data = ForwardData() - # The compressor can only recognize np type but not tensor. - forward_data.hidden = ( - self.h.detach().numpy() if isinstance(self.h, torch.Tensor) else self.h - ) - # The compressor in forward can only recognize np type but not tensor. - return forward_data - - def base_backward(self, gradient): + def base_backward(self): """backward on fusenet Args: @@ -64,16 +53,16 @@ def base_backward(self, gradient): return_hiddens = [] - if len(gradient) == len(self.h): - for i in range(len(gradient)): - return_hiddens.append(self.fuse_op.apply(self.h[i], gradient[i])) + if len(self._gradient) == len(self._h): + for i in range(len(self._gradient)): + return_hiddens.append(self.fuse_op.apply(self._h[i], self._gradient[i])) else: - gradient = ( - gradient[0] - if isinstance(gradient[0], torch.Tensor) - else torch.tensor(gradient[0]) + self._gradient = ( + self._gradient[0] + if isinstance(self._gradient[0], torch.Tensor) + else torch.tensor(self._gradient[0]) ) - return_hiddens.append(self.fuse_op.apply(self.h, gradient)) + return_hiddens.append(self.fuse_op.apply(self._h, self._gradient)) # apply gradients for base net self.optim_base.zero_grad() @@ -84,7 +73,7 @@ def base_backward(self, gradient): # clear intermediate results self.tape = None - self.h = None + self._h = None self.kwargs = {} def fuse_net( diff --git a/secretflow/ml/nn/sl/sl_model.py b/secretflow/ml/nn/sl/sl_model.py index 7fcafef6a..c273150c4 100644 --- a/secretflow/ml/nn/sl/sl_model.py +++ b/secretflow/ml/nn/sl/sl_model.py @@ -239,7 +239,7 @@ def handle_data( def handle_file( self, train_dict: Dict[PYU, str], - label: str = None, + label: Union[str, List[str], Tuple[str]] = None, sample_weight: Union[FedNdarray, VDataFrame] = None, batch_size=32, shuffle=False, @@ -307,7 +307,7 @@ def fit( List[Union[HDataFrame, VDataFrame, FedNdarray]], Dict[PYU, str], ], - y: Union[VDataFrame, FedNdarray, PYUObject, str], + y: Union[VDataFrame, FedNdarray, PYUObject, str, List[str], Tuple[str]], batch_size=32, epochs=1, verbose=1, @@ -384,9 +384,9 @@ def fit( # build dataset train_x, train_y = x, y if isinstance(train_x, Dict): - assert isinstance(train_y, str), ( + assert isinstance(train_y, (str, List, Tuple)), ( f"When the input x is type of Dict, the data will read from files, " - f"and y must be a label name with type str." + f"and y must be a label name with type str/List[str]/Tuple[str]." ) steps_per_epoch = self.handle_file( train_x, @@ -473,27 +473,29 @@ def fit( report_list.append(f"epoch: {epoch+1}/{epochs} - ") self._workers[self.device_y].reset_metrics() callbacks.on_epoch_begin(epoch=epoch) - - hiddens_buf = [None] * (self.pipeline_size - 1) + [worker.reset_data_iter(stage="train") for worker in self._workers.values()] + f_data_buf = [None] * (self.pipeline_size - 1) for step in range(0, steps_per_epoch + self.pipeline_size - 1): if step < steps_per_epoch: - hiddens = {} + f_datas = {} callbacks.on_train_batch_begin(step) for device, worker in self._workers.items(): # 1. Local calculation of basenet - hidden = worker.base_forward(stage="train", step=step) - hiddens[device] = hidden - hiddens_buf.append(hiddens) + worker.get_batch_data(stage="train") + worker.base_forward() + f_data = worker.pack_forward_data() + f_datas[device] = f_data + f_data_buf.append(f_datas) # clean up buffer - hiddens = hiddens_buf.pop(0) + f_datas = f_data_buf.pop(0) # Async transfer hiddens to label side - if hiddens is None: + if f_datas is None: continue # During pipeline strategy, the backpropagation process of the model will lag n cycles behind the forward propagation process. step = step - self.pipeline_size + 1 # do agglayer forward - agg_hiddens = self.agglayer.forward(hiddens) + agg_hiddens = self.agglayer.forward(f_datas) # 3. Fusenet do local calculates and return gradients gradients = self._workers[self.device_y].fuse_net(agg_hiddens) @@ -510,7 +512,8 @@ def fit( scatter_gradients = self.agglayer.backward(gradients) for device, worker in self._workers.items(): if device in scatter_gradients.keys(): - worker.base_backward(scatter_gradients[device]) + worker.recv_gradient(scatter_gradients[device]) + worker.base_backward() # for EarlyStoppingBatch, evalute model every early_stopping_batch_step if ( @@ -529,13 +532,19 @@ def fit( callbacks.on_test_begin() res = [] + [ + worker.reset_data_iter(stage="eval") + for worker in self._workers.values() + ] for val_step in range(0, valid_steps): callbacks.on_test_batch_begin(batch=val_step) - hiddens = {} # driver end + f_datas = {} # driver end for device, worker in self._workers.items(): - hidden = worker.base_forward("eval", step=val_step) - hiddens[device] = hidden - agg_hiddens = self.agglayer.forward(hiddens) + worker.get_batch_data(stage="eval") + worker.base_forward() + f_data = worker.pack_forward_data() + f_datas[device] = f_data + agg_hiddens = self.agglayer.forward(f_datas) metrics = self._workers[self.device_y].evaluate(agg_hiddens) res.append(metrics) @@ -569,21 +578,26 @@ def fit( wait(res) res = [] assert ( - len(hiddens_buf) == 0 - ), f'hiddens buffer unfinished, len: {len(hiddens_buf)}' + len(f_data_buf) == 0 + ), f'hiddens buffer unfinished, len: {len(f_data_buf)}' if validation and epoch % validation_freq == 0: callbacks.on_test_begin() # validation self._workers[self.device_y].reset_metrics() - + [ + worker.reset_data_iter(stage="eval") + for worker in self._workers.values() + ] res = [] for step in range(0, valid_steps): callbacks.on_test_batch_begin(batch=step) - hiddens = {} # driver end + f_datas = {} # driver end for device, worker in self._workers.items(): - hidden = worker.base_forward("eval", step=step) - hiddens[device] = hidden - agg_hiddens = self.agglayer.forward(hiddens) + worker.get_batch_data(stage="eval") + worker.base_forward() + f_data = worker.pack_forward_data() + f_datas[device] = f_data + agg_hiddens = self.agglayer.forward(f_datas) metrics = self._workers[self.device_y].evaluate(agg_hiddens) res.append(metrics) @@ -679,13 +693,16 @@ def predict( wait_steps = min(min(self.get_cpus()) * 2, 100) res = [] callbacks.on_predict_begin() + [worker.reset_data_iter(stage="eval") for worker in self._workers.values()] for step in range(0, predict_steps): callbacks.on_predict_batch_begin(step) forward_data_dict = {} for device, worker in self._workers.items(): if device not in self.base_model_dict: continue - f_data = worker.base_forward(stage="eval", step=step) + worker.get_batch_data(stage="eval") + worker.base_forward() + f_data = worker.pack_forward_data() forward_data_dict[device] = f_data agg_hiddens = self.agglayer.forward(forward_data_dict) @@ -769,15 +786,18 @@ def evaluate( steps=evaluate_steps, ) callbacks.on_test_begin() + [worker.reset_data_iter(stage="eval") for worker in self._workers.values()] wait_steps = min(min(self.get_cpus()) * 2, 100) for step in range(0, evaluate_steps): callbacks.on_test_batch_begin(step) - hiddens = {} # driver端 + f_datas = {} # driver端 for device, worker in self._workers.items(): - hidden = worker.base_forward(stage="eval", step=step) - hiddens[device] = hidden + worker.get_batch_data(stage="eval") + worker.base_forward() + f_data = worker.pack_forward_data() + f_datas[device] = f_data - agg_hiddens = self.agglayer.forward(hiddens) + agg_hiddens = self.agglayer.forward(f_datas) metrics = self._workers[self.device_y].evaluate(agg_hiddens) if (step + 1) % wait_steps == 0: diff --git a/secretflow/protos/secretflow/spec/extend/calculate_rules.proto b/secretflow/protos/secretflow/spec/extend/calculate_rules.proto index bcfd4f072..a2438771b 100644 --- a/secretflow/protos/secretflow/spec/extend/calculate_rules.proto +++ b/secretflow/protos/secretflow/spec/extend/calculate_rules.proto @@ -29,28 +29,28 @@ message CalculateOpRules { NORMALIZATION = 2; // len(operands) == 2, [min, max] RANGE_LIMIT = 3; - // len(operands) == 3, [unary_op(+ - * /), value, value_is_right] - // value_is_right ∈ [true, false] - // if value_is_right == true, column unary_op value - // if value_is_right == false, value unary_op column + // len(operands) == 3, [(+ -), unary_op(+ - * /), value] + // if operandsp[0] == "+", column unary_op value + // if operandsp[0] == "-", value unary_op column UNARY = 4; // len(operands) == 0 - ROUND = 5; + RECIPROCAL = 5; + // len(operands) == 0 + ROUND = 6; // len(operands) == 1, [bias] - LOG_ROUND = 6; + LOG_ROUND = 7; // len(operands) == 0 - SQRT = 7; + SQRT = 8; // len(operands) == 2, [log_base, bias] - LOG = 8; - // len(operands) == 0 - EXP = 9; + LOG = 9; // len(operands) == 0 - CONCAT = 10; + EXP = 10; // len(operands) == 0 LENGTH = 11; // len(operands) == 2, [start_pos, length] SUBSTR = 12; } + OpType op = 1; repeated string operands = 2; diff --git a/secretflow/spec/extend/calculate_rules_pb2.py b/secretflow/spec/extend/calculate_rules_pb2.py new file mode 100644 index 000000000..60a40ef36 --- /dev/null +++ b/secretflow/spec/extend/calculate_rules_pb2.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: secretflow/protos/secretflow/spec/extend/calculate_rules.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n>secretflow/protos/secretflow/spec/extend/calculate_rules.proto\x12\x16secretflow.spec.extend\"\xab\x02\n\x10\x43\x61lculateOpRules\x12;\n\x02op\x18\x01 \x01(\x0e\x32/.secretflow.spec.extend.CalculateOpRules.OpType\x12\x10\n\x08operands\x18\x02 \x03(\t\x12\x14\n\x0cnew_col_name\x18\x03 \x01(\t\"\xb1\x01\n\x06OpType\x12\t\n\x05INVAL\x10\x00\x12\x0f\n\x0bSTANDARDIZE\x10\x01\x12\x11\n\rNORMALIZATION\x10\x02\x12\x0f\n\x0bRANGE_LIMIT\x10\x03\x12\t\n\x05UNARY\x10\x04\x12\x0e\n\nRECIPROCAL\x10\x05\x12\t\n\x05ROUND\x10\x06\x12\r\n\tLOG_ROUND\x10\x07\x12\x08\n\x04SQRT\x10\x08\x12\x07\n\x03LOG\x10\t\x12\x07\n\x03\x45XP\x10\n\x12\n\n\x06LENGTH\x10\x0b\x12\n\n\x06SUBSTR\x10\x0c\x42\x1c\n\x1aorg.secretflow.spec.extendb\x06proto3' +) + + +_CALCULATEOPRULES = DESCRIPTOR.message_types_by_name['CalculateOpRules'] +_CALCULATEOPRULES_OPTYPE = _CALCULATEOPRULES.enum_types_by_name['OpType'] +CalculateOpRules = _reflection.GeneratedProtocolMessageType( + 'CalculateOpRules', + (_message.Message,), + { + 'DESCRIPTOR': _CALCULATEOPRULES, + '__module__': 'secretflow.protos.secretflow.spec.extend.calculate_rules_pb2' + # @@protoc_insertion_point(class_scope:secretflow.spec.extend.CalculateOpRules) + }, +) +_sym_db.RegisterMessage(CalculateOpRules) + +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\032org.secretflow.spec.extend' + _CALCULATEOPRULES._serialized_start = 91 + _CALCULATEOPRULES._serialized_end = 390 + _CALCULATEOPRULES_OPTYPE._serialized_start = 213 + _CALCULATEOPRULES_OPTYPE._serialized_end = 390 +# @@protoc_insertion_point(module_scope) diff --git a/secretflow/ic/proto/handshake/algos/__init__.py b/tests/component/infra/__init__.py similarity index 100% rename from secretflow/ic/proto/handshake/algos/__init__.py rename to tests/component/infra/__init__.py diff --git a/tests/component/test_component.py b/tests/component/infra/test_component.py similarity index 100% rename from tests/component/test_component.py rename to tests/component/infra/test_component.py diff --git a/tests/component/test_eval_param_reader.py b/tests/component/infra/test_eval_param_reader.py similarity index 100% rename from tests/component/test_eval_param_reader.py rename to tests/component/infra/test_eval_param_reader.py diff --git a/tests/component/test_i18n.py b/tests/component/infra/test_i18n.py similarity index 100% rename from tests/component/test_i18n.py rename to tests/component/infra/test_i18n.py diff --git a/secretflow/ic/proto/handshake/protocol_family/__init__.py b/tests/component/io/__init__.py similarity index 100% rename from secretflow/ic/proto/handshake/protocol_family/__init__.py rename to tests/component/io/__init__.py diff --git a/tests/component/test_identity.py b/tests/component/io/test_identity.py similarity index 100% rename from tests/component/test_identity.py rename to tests/component/io/test_identity.py diff --git a/tests/component/test_io_ss_glm.py b/tests/component/io/test_io_ss_glm.py similarity index 100% rename from tests/component/test_io_ss_glm.py rename to tests/component/io/test_io_ss_glm.py diff --git a/tests/component/test_io_vert_bin.py b/tests/component/io/test_io_vert_bin.py similarity index 100% rename from tests/component/test_io_vert_bin.py rename to tests/component/io/test_io_vert_bin.py diff --git a/tests/component/test_io_woe_bin.py b/tests/component/io/test_io_woe_bin.py similarity index 100% rename from tests/component/test_io_woe_bin.py rename to tests/component/io/test_io_woe_bin.py diff --git a/tests/component/ml/__init__.py b/tests/component/ml/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/component/test_biclassification_eval.py b/tests/component/ml/test_biclassification_eval.py similarity index 100% rename from tests/component/test_biclassification_eval.py rename to tests/component/ml/test_biclassification_eval.py diff --git a/tests/component/test_prediction_bias_eval.py b/tests/component/ml/test_prediction_bias_eval.py similarity index 100% rename from tests/component/test_prediction_bias_eval.py rename to tests/component/ml/test_prediction_bias_eval.py diff --git a/tests/component/test_regression_eval.py b/tests/component/ml/test_regression_eval.py similarity index 100% rename from tests/component/test_regression_eval.py rename to tests/component/ml/test_regression_eval.py diff --git a/tests/component/test_sgb.py b/tests/component/ml/test_sgb.py similarity index 97% rename from tests/component/test_sgb.py rename to tests/component/ml/test_sgb.py index 7dfa1784d..afca2c085 100644 --- a/tests/component/test_sgb.py +++ b/tests/component/ml/test_sgb.py @@ -81,11 +81,13 @@ def get_pred_param(alice_path, bob_path, train_res, predict_path): "receiver", "save_ids", "save_label", + "batch_size", ], attrs=[ Attribute(s="alice"), Attribute(b=False), Attribute(b=True), + Attribute(i64=50), ], inputs=[ train_res.outputs[0], @@ -206,6 +208,11 @@ def test_sgb(comp_prod_sf_cluster_config): input_y = pd.read_csv(os.path.join(TEST_STORAGE_ROOT, "alice", alice_path)) output_y = pd.read_csv(os.path.join(TEST_STORAGE_ROOT, "alice", predict_path)) + output_it = IndividualTable() + + assert predict_res.outputs[0].meta.Unpack(output_it) + assert output_it.line_count == input_y.shape[0] + # label & pred assert output_y.shape[1] == 2 diff --git a/tests/component/test_ss_glm.py b/tests/component/ml/test_ss_glm.py similarity index 100% rename from tests/component/test_ss_glm.py rename to tests/component/ml/test_ss_glm.py diff --git a/tests/component/test_ss_sgd.py b/tests/component/ml/test_ss_sgd.py similarity index 100% rename from tests/component/test_ss_sgd.py rename to tests/component/ml/test_ss_sgd.py diff --git a/tests/component/test_ss_xgb.py b/tests/component/ml/test_ss_xgb.py similarity index 100% rename from tests/component/test_ss_xgb.py rename to tests/component/ml/test_ss_xgb.py diff --git a/tests/component/preprocessing/__init__.py b/tests/component/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/component/test_binary_op.py b/tests/component/preprocessing/test_binary_op.py similarity index 100% rename from tests/component/test_binary_op.py rename to tests/component/preprocessing/test_binary_op.py diff --git a/tests/component/test_case_when.py b/tests/component/preprocessing/test_case_when.py similarity index 100% rename from tests/component/test_case_when.py rename to tests/component/preprocessing/test_case_when.py diff --git a/tests/component/test_condition_filter.py b/tests/component/preprocessing/test_condition_filter.py similarity index 100% rename from tests/component/test_condition_filter.py rename to tests/component/preprocessing/test_condition_filter.py diff --git a/tests/component/preprocessing/test_feature_calculate.py b/tests/component/preprocessing/test_feature_calculate.py new file mode 100644 index 000000000..9d37eb274 --- /dev/null +++ b/tests/component/preprocessing/test_feature_calculate.py @@ -0,0 +1,491 @@ +import os + +import numpy as np + +import pandas as pd +from google.protobuf.json_format import MessageToJson + +from secretflow.component.data_utils import DistDataType +from secretflow.component.preprocessing.feature_calculate import feature_calculate +from secretflow.spec.extend.calculate_rules_pb2 import CalculateOpRules +from secretflow.spec.v1.component_pb2 import Attribute +from secretflow.spec.v1.data_pb2 import DistData, TableSchema, VerticalTable +from secretflow.spec.v1.evaluation_pb2 import NodeEvalParam +from sklearn.preprocessing import MinMaxScaler, StandardScaler + +from tests.conftest import TEST_STORAGE_ROOT + +test_data_alice = pd.DataFrame( + { + "a1": [i * (-0.8) for i in range(3)], + "a2": [0.1] * 3, + "a3": ["AAA", "BBB", "CCC"], + } +) + +test_data_bob = pd.DataFrame( + { + "b1": [i for i in range(3)], + } +) + + +def _almost_equal(df1, df2, rtol=1.0e-5): + try: + pd.testing.assert_frame_equal(df1, df2, rtol) + return True + except AssertionError: + return False + + +def _build_test(): + names = [] + tests = [] + features = [] + expected = [] + + # --------TEST STANDARDIZE--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.STANDARDIZE + scaler = StandardScaler() + + alice_data = test_data_alice.copy() + alice_feature = ['a1', 'a2'] + alice_data[alice_feature] = scaler.fit_transform(alice_data[alice_feature]) + alice_data = alice_data.reindex(['a3', 'a1', 'a2'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data[bob_feature] = scaler.fit_transform(bob_data[bob_feature]) + + names.append("standardize") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST NORMALIZATION--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.NORMALIZATION + scaler = MinMaxScaler() + + alice_data = test_data_alice.copy() + alice_feature = ['a1', 'a2'] + alice_data[alice_feature] = scaler.fit_transform(alice_data[alice_feature]) + alice_data = alice_data.reindex(['a3', 'a1', 'a2'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data[bob_feature] = scaler.fit_transform(bob_data[bob_feature]) + + names.append("normalization") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST RANGE_LIMIT--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.RANGE_LIMIT + rule.operands.extend(["1", "2"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data.loc[alice_data['a1'] < 1, 'a1'] = 1 + alice_data.loc[alice_data['a1'] > 2, 'a1'] = 2 + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data.loc[bob_data['b1'] < 1, 'b1'] = 1 + bob_data.loc[bob_data['b1'] > 2, 'b1'] = 2 + bob_data = bob_data.astype(float) + + names.append("range_limit") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST UNARY(+)--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.UNARY + rule.operands.extend(["+", "+", "1"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = alice_data['a1'] + 1.0 + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = bob_data['b1'] + 1.0 + + names.append("unary_+") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST UNARY(-)--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.UNARY + rule.operands.extend(["+", "-", "1"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = alice_data['a1'] - 1.0 + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = bob_data['b1'] - 1.0 + + names.append("unary_-") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST UNARY(reverse-)--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.UNARY + rule.operands.extend(["-", "-", "1"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = 1.0 - alice_data['a1'] + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = 1.0 - bob_data['b1'] + + names.append("unary_reverse_-") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST UNARY(*)--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.UNARY + rule.operands.extend(["+", "*", "2"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = alice_data['a1'] * 2.0 + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = bob_data['b1'] * 2.0 + + names.append("unary_*") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST UNARY(/)--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.UNARY + rule.operands.extend(["+", "/", "3"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = alice_data['a1'] / 3.0 + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = bob_data['b1'] / 3.0 + + names.append("unary_/") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST UNARY(reverse/)--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.UNARY + rule.operands.extend(["-", "/", "3"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = 3.0 / alice_data['a1'] + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = 3.0 / bob_data['b1'] + + names.append("unary_reverse_/") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST RECIPROCAL--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.RECIPROCAL + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = 1.0 / alice_data['a1'] + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = 1.0 / bob_data['b1'] + + names.append("reciprocal") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST ROUND--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.ROUND + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = alice_data['a1'].round() + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = bob_data['b1'].round() * 1.0 + + names.append("round") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST LOGROUND--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.LOG_ROUND + rule.operands.extend(["10"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = np.log2(alice_data['a1'] + 10) + alice_data['a1'] = alice_data['a1'].round() + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = np.log2(bob_data['b1'] + 10) + bob_data['b1'] = bob_data['b1'].round() + + names.append("log_round") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST SQRT--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.SQRT + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = np.sqrt(alice_data['a1']) + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = np.sqrt(bob_data['b1']) + + names.append("sqrt") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST LOG BASE E--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.LOG + rule.operands.extend(["e", "10"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = np.log(alice_data['a1'] + 10) + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = np.log(bob_data['b1'] + 10) + + names.append("log base e") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST LOG BASE 2--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.LOG + rule.operands.extend(["2", "10"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = np.log(alice_data['a1'] + 10) / np.log(2) + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = np.log(bob_data['b1'] + 10) / np.log(2) + + names.append("log base 2") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST EXP--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.EXP + + alice_data = test_data_alice.copy() + alice_feature = ['a1'] + alice_data['a1'] = np.exp(alice_data['a1']) + alice_data = alice_data.reindex(['a2', 'a3', 'a1'], axis=1) + + bob_data = test_data_bob.copy() + bob_feature = ['b1'] + bob_data['b1'] = np.exp(bob_data['b1']) + + names.append("exp") + tests.append(rule) + features.append(alice_feature + bob_feature) + expected.append((alice_data, bob_data)) + + # --------TEST LENGTH--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.LENGTH + + alice_data = test_data_alice.copy() + alice_feature = ['a3'] + alice_data['a3'] = alice_data['a3'].str.len() + alice_data = alice_data.reindex(['a1', 'a2', 'a3'], axis=1) + + names.append("length") + tests.append(rule) + features.append(alice_feature) + expected.append((alice_data, test_data_bob)) + + # --------TEST SUBSTR--------- + rule = CalculateOpRules() + rule.op = CalculateOpRules.OpType.SUBSTR + rule.operands.extend(["0", "2"]) + + alice_data = test_data_alice.copy() + alice_feature = ['a3'] + alice_data['a3'] = alice_data['a3'].str[:2] + alice_data = alice_data.reindex(['a1', 'a2', 'a3'], axis=1) + + names.append("substr") + tests.append(rule) + features.append(alice_feature) + expected.append((alice_data, test_data_bob)) + + return names, tests, features, expected + + +def test_feature_calculate(comp_prod_sf_cluster_config): + alice_input_path = "test_feature_calculate/alice.csv" + bob_input_path = "test_feature_calculate/bob.csv" + out_path = "test_feature_calculate/out.csv" + rule_path = "test_feature_calculate/feature_calculate.rule" + + storage_config, sf_cluster_config = comp_prod_sf_cluster_config + self_party = sf_cluster_config.private_config.self_party + local_fs_wd = storage_config.local_fs.wd + df_alice = pd.DataFrame() + if self_party == "alice": + df_alice = test_data_alice + os.makedirs( + os.path.join(local_fs_wd, "test_feature_calculate"), + exist_ok=True, + ) + + df_alice.to_csv( + os.path.join(local_fs_wd, alice_input_path), + index=False, + ) + elif self_party == "bob": + df_bob = test_data_bob + os.makedirs( + os.path.join(local_fs_wd, "test_feature_calculate"), + exist_ok=True, + ) + + df_bob.to_csv( + os.path.join(local_fs_wd, bob_input_path), + index=False, + ) + + param = NodeEvalParam( + domain="preprocessing", + name="feature_calculate", + version="0.0.1", + attr_paths=[ + "rules", + "input/in_ds/features", + ], + attrs=[ + Attribute(s="{}"), + Attribute(ss=[]), + ], + inputs=[ + DistData( + name="input_data", + type=str(DistDataType.VERTICAL_TABLE), + data_refs=[ + DistData.DataRef(uri=bob_input_path, party="bob", format="csv"), + DistData.DataRef(uri=alice_input_path, party="alice", format="csv"), + ], + ) + ], + output_uris=[ + out_path, + rule_path, + ], + ) + + meta = VerticalTable( + schemas=[ + TableSchema( + feature_types=["int32"], + features=["b1"], + ), + TableSchema( + feature_types=[ + "float64", + "float64", + "str", + ], + features=["a1", "a2", "a3"], + ), + ], + ) + + param.inputs[0].meta.Pack(meta) + + os.makedirs( + os.path.join(local_fs_wd, "test_feature_calculate"), + exist_ok=True, + ) + + for n, t, f, e in zip(*_build_test()): + param.attrs[0].s = MessageToJson(t) + param.attrs[1].ss[:] = f + + res = feature_calculate.eval( + param=param, + storage_config=storage_config, + cluster_config=sf_cluster_config, + ) + + assert len(res.outputs) == 2 + + alice_out = pd.read_csv(os.path.join(TEST_STORAGE_ROOT, "alice", out_path)) + assert _almost_equal( + alice_out, e[0] + ), f"{n}\n===out===\n{alice_out}\n===e===\n{e[0]}\n===r===\n{param.attrs[0].s}" + + bob_out = pd.read_csv(os.path.join(TEST_STORAGE_ROOT, "bob", out_path)) + assert _almost_equal( + bob_out, e[1] + ), f"{n}\n===out===\n{bob_out}\n===e===\n{e[1]}\n===r===\n{param.attrs[0].s}" + + assert len(res.outputs) == 2 diff --git a/tests/component/test_feature_filter.py b/tests/component/preprocessing/test_feature_filter.py similarity index 100% rename from tests/component/test_feature_filter.py rename to tests/component/preprocessing/test_feature_filter.py diff --git a/tests/component/test_fillna.py b/tests/component/preprocessing/test_fillna.py similarity index 93% rename from tests/component/test_fillna.py rename to tests/component/preprocessing/test_fillna.py index cf45e3f81..127ee8885 100644 --- a/tests/component/test_fillna.py +++ b/tests/component/preprocessing/test_fillna.py @@ -3,9 +3,10 @@ import numpy as np import pandas as pd +import pytest from secretflow.component.data_utils import DistDataType -from secretflow.component.preprocessing.fillna import fillna +from secretflow.component.preprocessing.fillna import fillna, SUPPORTED_FILL_NA_METHOD from secretflow.spec.v1.component_pb2 import Attribute from secretflow.spec.v1.data_pb2 import DistData, TableSchema, VerticalTable from secretflow.spec.v1.evaluation_pb2 import NodeEvalParam @@ -13,7 +14,8 @@ from tests.conftest import TEST_STORAGE_ROOT -def test_fillna(comp_prod_sf_cluster_config): +@pytest.mark.parametrize("strategy", SUPPORTED_FILL_NA_METHOD) +def test_fillna(comp_prod_sf_cluster_config, strategy): alice_input_path = "test_fillna/alice.csv" bob_input_path = "test_fillna/bob.csv" rule_path = "test_fillna/fillna.rule" @@ -72,7 +74,7 @@ def test_fillna(comp_prod_sf_cluster_config): 'input/input_dataset/fill_na_features', ], attrs=[ - Attribute(s="constant"), + Attribute(s=strategy), Attribute(f=99.0), Attribute(ss=["a2", "b4", "b5"]), ], diff --git a/tests/component/test_onehot_encode.py b/tests/component/preprocessing/test_onehot_encode.py similarity index 100% rename from tests/component/test_onehot_encode.py rename to tests/component/preprocessing/test_onehot_encode.py diff --git a/tests/component/test_psi.py b/tests/component/preprocessing/test_psi.py similarity index 100% rename from tests/component/test_psi.py rename to tests/component/preprocessing/test_psi.py diff --git a/tests/component/test_train_test_split.py b/tests/component/preprocessing/test_train_test_split.py similarity index 100% rename from tests/component/test_train_test_split.py rename to tests/component/preprocessing/test_train_test_split.py diff --git a/tests/component/test_vert_binning.py b/tests/component/preprocessing/test_vert_binning.py similarity index 100% rename from tests/component/test_vert_binning.py rename to tests/component/preprocessing/test_vert_binning.py diff --git a/tests/component/test_woe_binning.py b/tests/component/preprocessing/test_woe_binning.py similarity index 100% rename from tests/component/test_woe_binning.py rename to tests/component/preprocessing/test_woe_binning.py diff --git a/tests/component/stats/__init__.py b/tests/component/stats/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/component/test_groupby_statistics.py b/tests/component/stats/test_groupby_statistics.py similarity index 99% rename from tests/component/test_groupby_statistics.py rename to tests/component/stats/test_groupby_statistics.py index 43b0c6de8..a463b8fba 100644 --- a/tests/component/test_groupby_statistics.py +++ b/tests/component/stats/test_groupby_statistics.py @@ -67,7 +67,7 @@ def test_groupby_statistics(comp_prod_sf_cluster_config, by, value_agg_pairs): param = NodeEvalParam( domain="stats", name="groupby_statistics", - version="0.0.2", + version="0.0.3", attr_paths=["input/input_data/by", "aggregation_config"], attrs=[ Attribute(ss=by), diff --git a/tests/component/test_ss_pearsonr.py b/tests/component/stats/test_ss_pearsonr.py similarity index 100% rename from tests/component/test_ss_pearsonr.py rename to tests/component/stats/test_ss_pearsonr.py diff --git a/tests/component/test_ss_pvalue.py b/tests/component/stats/test_ss_pvalue.py similarity index 100% rename from tests/component/test_ss_pvalue.py rename to tests/component/stats/test_ss_pvalue.py diff --git a/tests/component/test_ss_vif.py b/tests/component/stats/test_ss_vif.py similarity index 100% rename from tests/component/test_ss_vif.py rename to tests/component/stats/test_ss_vif.py diff --git a/tests/component/test_table_statistics.py b/tests/component/stats/test_table_statistics.py similarity index 100% rename from tests/component/test_table_statistics.py rename to tests/component/stats/test_table_statistics.py diff --git a/tests/component/test_batch_reader.py b/tests/component/test_batch_reader.py new file mode 100644 index 000000000..7ccabe8e6 --- /dev/null +++ b/tests/component/test_batch_reader.py @@ -0,0 +1,71 @@ +import os + +import pandas as pd +from sklearn.datasets import load_breast_cancer +from sklearn.preprocessing import StandardScaler + +from secretflow import reveal, wait +from secretflow.component.batch_reader import SimpleVerticalBatchReader + + +def test_works(sf_production_setup_devices): + scaler = StandardScaler() + ds = load_breast_cancer() + x, y = scaler.fit_transform(ds["data"]), ds["target"] + + expected_row_cnt = x.shape[0] + cols = {} + paths = { + 'alice': os.path.join("tmp", "alice", "test_batch_reader", "alice.csv"), + 'bob': os.path.join("tmp", "bob", "test_batch_reader", "bob.csv"), + } + + def create_alice_data(p, x, y): + os.makedirs( + os.path.dirname(p), + exist_ok=True, + ) + x = pd.DataFrame(x[:, :15], columns=[f"a{i}" for i in range(15)]) + y = pd.DataFrame(y, columns=["y"]) + ds = pd.concat([x, y], axis=1) + ds.to_csv(p, index=False) + + wait(sf_production_setup_devices.alice(create_alice_data)(paths["alice"], x, y)) + + def create_bob_data(p, x): + os.makedirs( + os.path.dirname(p), + exist_ok=True, + ) + + ds = pd.DataFrame(x[:, 15:], columns=[f"b{i}" for i in range(15)]) + ds.to_csv(p, index=False) + + wait(sf_production_setup_devices.bob(create_bob_data)(paths["bob"], x)) + + cols = { + "alice": [f"a{i}" for i in range(15)], + "bob": [f"b{i}" for i in range(15)], + } + + reader = SimpleVerticalBatchReader( + paths=paths, + batch_size=50, + cols=cols, + ) + + row_cnt = 0 + for batches in reader: + alice_df, bob_df = reveal(batches["alice"]), reveal(batches["bob"]) + + assert alice_df.shape[0] == bob_df.shape[0] + + row_cnt += alice_df.shape[0] + + assert alice_df.shape[1] == 15 + + assert bob_df.shape[1] == 15 + + assert row_cnt == reader.total_read_cnt() + + assert expected_row_cnt == row_cnt diff --git a/tests/conftest.py b/tests/conftest.py index 5ec3c35fe..4ca2685e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -158,7 +158,7 @@ class DeviceInventory: def sf_memory_setup_devices(request): devices = DeviceInventory() sfd.set_distribution_mode(mode=DISTRIBUTION_MODE.DEBUG) - sf.shutdown() + sf.shutdown(barrier_on_shutdown=True) sf.init( ["alice", "bob", "carol", "davy", "spu"], debug_mode=True, @@ -177,14 +177,14 @@ def sf_memory_setup_devices(request): yield devices del devices - sf.shutdown() + sf.shutdown(barrier_on_shutdown=True) @pytest.fixture(scope="module", params=[semi2k_cluster]) def sf_simulation_setup_devices(request): devices = DeviceInventory() sfd.set_distribution_mode(mode=DISTRIBUTION_MODE.SIMULATION) - sf.shutdown() + sf.shutdown(barrier_on_shutdown=True) sf.init( ["alice", "bob", "carol", "davy"], address="local", @@ -212,7 +212,7 @@ def sf_simulation_setup_devices(request): yield devices del devices - sf.shutdown() + sf.shutdown(barrier_on_shutdown=True) @pytest.fixture(scope="session", params=SF_PARTIES) @@ -229,7 +229,7 @@ def sf_production_setup_devices_grpc(request, sf_party_for_4pc): address="local", num_cpus=32, log_to_driver=True, - logging_level='debug', + logging_level='info', cluster_config=cluster(), enable_waiting_for_other_parties_ready=False, ) @@ -265,7 +265,7 @@ def sf_production_setup_devices_grpc(request, sf_party_for_4pc): yield devices del devices - sf.shutdown() + sf.shutdown(barrier_on_shutdown=True) @pytest.fixture(scope="module") @@ -277,7 +277,7 @@ def sf_production_setup_devices(request, sf_party_for_4pc): address="local", num_cpus=32, log_to_driver=True, - logging_level='debug', + logging_level='info', cluster_config=cluster(), enable_waiting_for_other_parties_ready=False, cross_silo_comm_backend="brpc_link", @@ -323,7 +323,7 @@ def sf_production_setup_devices(request, sf_party_for_4pc): yield devices del devices - sf.shutdown() + sf.shutdown(barrier_on_shutdown=True) @pytest.fixture(scope="module") @@ -335,7 +335,7 @@ def sf_production_setup_devices_aby3(request, sf_party_for_4pc): address="local", num_cpus=32, log_to_driver=True, - logging_level='debug', + logging_level='info', cluster_config=cluster(), enable_waiting_for_other_parties_ready=False, cross_silo_comm_backend="brpc_link", @@ -381,7 +381,7 @@ def sf_production_setup_devices_aby3(request, sf_party_for_4pc): yield devices del devices - sf.shutdown() + sf.shutdown(barrier_on_shutdown=True) TEST_STORAGE_ROOT = os.path.join(tempfile.gettempdir(), getpass.getuser()) diff --git a/tests/device/test_teeu.py b/tests/device/test_teeu.py index 148f8f67a..2f35647f2 100644 --- a/tests/device/test_teeu.py +++ b/tests/device/test_teeu.py @@ -34,7 +34,7 @@ def teeu_production_setup_devices(request, sf_party_for_4pc): sf.init( address='local', cluster_config=cluster(), - logging_level='debug', + logging_level='info', num_cpus=8, log_to_driver=True, tee_simulation=True, diff --git a/tests/ml/boost/sgb_v/test_vert_sgb.py b/tests/ml/boost/sgb_v/test_vert_sgb.py index 138ee78d1..591df8696 100644 --- a/tests/ml/boost/sgb_v/test_vert_sgb.py +++ b/tests/ml/boost/sgb_v/test_vert_sgb.py @@ -37,8 +37,8 @@ def _run_sgb( subsample, colsample, audit_dict={}, - auc_bar=0.9, - mse_hat=1, + auc_bar=0.88, + mse_hat=1.1, tree_grow_method='level', enable_goss=False, early_stop_criterion_g_abs_sum=10.0, @@ -112,6 +112,11 @@ def _run_sgb( device: "./" + test_name + "/" + device.party for device in v_data.partitions.keys() } + label_holder_device = list(label_data.partitions.keys())[0] + if label_holder_device not in saving_path_dict: + saving_path_dict[label_holder_device] = ( + "./" + test_name + "/" + label_holder_device.party + ) model.save_model(saving_path_dict) model_loaded = load_model(saving_path_dict, env.alice) fed_yhat_loaded = model_loaded.predict(v_data, env.alice) @@ -124,7 +129,7 @@ def _run_sgb( ) -def _run_npc_linear(env, test_name, parts, label_device): +def _run_npc_linear(env, test_name, parts, label_device, auc=0.88): vdf = load_linear(parts=parts) label_data = vdf['y'] @@ -137,11 +142,23 @@ def _run_npc_linear(env, test_name, parts, label_device): label_data = label_data[:500, :] logging.info("running XGB style test") - _run_sgb(env, test_name, v_data, label_data, y, True, 0.9, 1) + _run_sgb(env, test_name, v_data, label_data, y, True, 0.9, 1, auc_bar=auc) logging.info("running lightGBM style test") # test with leaf wise growth and goss: lightGBM style _run_sgb( - env, test_name, v_data, label_data, y, True, 0.9, 1, {}, 0.9, 2.3, 'leaf', True + env, + test_name, + v_data, + label_data, + y, + True, + 0.9, + 1, + {}, + auc, + 2.3, + 'leaf', + True, ) @@ -173,6 +190,20 @@ def test_4pc_linear(sf_production_setup_devices_aby3): ) +def test_2pc_linear_minimal(sf_production_setup_devices_aby3): + parts = { + sf_production_setup_devices_aby3.davy: (1, 2), + sf_production_setup_devices_aby3.alice: (21, 22), + } + _run_npc_linear( + sf_production_setup_devices_aby3, + "2pc_linear_minimal", + parts, + sf_production_setup_devices_aby3.alice, + auc=0.55, + ) + + def test_breast_cancer(sf_production_setup_devices_aby3): from sklearn.datasets import load_breast_cancer diff --git a/tests/ml/boost/ss_xgb_v/test_vert_ss_xgb.py b/tests/ml/boost/ss_xgb_v/test_vert_ss_xgb.py index 413917c23..4865a2e19 100644 --- a/tests/ml/boost/ss_xgb_v/test_vert_ss_xgb.py +++ b/tests/ml/boost/ss_xgb_v/test_vert_ss_xgb.py @@ -15,13 +15,13 @@ import os import time -from sklearn.metrics import mean_squared_error, roc_auc_score - from secretflow.data import FedNdarray, PartitionWay from secretflow.device.driver import reveal, wait from secretflow.ml.boost.ss_xgb_v import Xgb from secretflow.utils.simulation.datasets import load_dermatology, load_linear +from sklearn.metrics import mean_squared_error, roc_auc_score + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' @@ -108,21 +108,6 @@ def test_3pc_linear(sf_production_setup_devices_aby3): ) -def test_4pc_linear(sf_production_setup_devices_aby3): - parts = { - sf_production_setup_devices_aby3.alice: (1, 6), - sf_production_setup_devices_aby3.bob: (6, 12), - sf_production_setup_devices_aby3.carol: (12, 18), - sf_production_setup_devices_aby3.davy: (18, 22), - } - _run_npc_linear( - sf_production_setup_devices_aby3, - "4pc_linear", - parts, - sf_production_setup_devices_aby3.davy, - ) - - def test_breast_cancer(sf_production_setup_devices_aby3): from sklearn.datasets import load_breast_cancer diff --git a/tests/ml/nn/sl/applications/test_sl_model_torch_cv.py b/tests/ml/nn/sl/applications/test_sl_model_torch_cv.py new file mode 100644 index 000000000..01fe55169 --- /dev/null +++ b/tests/ml/nn/sl/applications/test_sl_model_torch_cv.py @@ -0,0 +1,87 @@ +import torch +from torch import nn, optim +from torch.utils.data import DataLoader, Dataset + +from secretflow.ml.nn.applications.sl_resnet_torch import ( + BasicBlock, + ResNetBase, + ResNetFuse, +) +from secretflow.ml.nn.applications.sl_vgg_torch import VGGBase, VGGFuse +from secretflow.ml.nn.utils import BaseModule + + +class SimSLVGG16(BaseModule): + def __init__(self): + super(SimSLVGG16, self).__init__() + self.alice_base = VGGBase() + self.bob_base = VGGBase() + self.fuse = VGGFuse() + + def forward(self, x): + alice_hid = self.alice_base(x[0]) + bob_hid = self.bob_base(x[1]) + out = self.fuse((alice_hid, bob_hid)) + return out + + +class SimSLResNet18(BaseModule): + def __init__(self): + super(SimSLResNet18, self).__init__() + self.alice_base = ResNetBase(BasicBlock, [2, 2, 2, 2]) + self.bob_base = ResNetBase(BasicBlock, [2, 2, 2, 2]) + self.fuse = ResNetFuse() + + def forward(self, x): + alice_hid = self.alice_base(x[0]) + bob_hid = self.bob_base(x[1]) + out = self.fuse((alice_hid, bob_hid)) + return out + + +class CustomDataset(Dataset): + def __init__(self, data, labels, split_idx): + self.data = data + self.labels = labels + self.split_idx = split_idx + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + data_alice = self.data[index][..., : self.split_idx] + data_bob = self.data[index][..., self.split_idx :] + label = self.labels[index] + return (data_alice, data_bob), label + + +def simulate_sl_model_training(model): + data_num = 20 + data = torch.randn(data_num, 3, 32, 32) + labels = torch.randint(0, 10, (data_num,)) + + custom_dataset = CustomDataset(data, labels, split_idx=16) + + batch_size = 4 + data_loader = DataLoader( + dataset=custom_dataset, batch_size=batch_size, shuffle=True + ) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters()) + + for batch_id, (data, label) in enumerate(data_loader): + optimizer.zero_grad() + output = model(data) + loss = criterion(output, label) + loss.backward() + optimizer.step() + print(f'batch {batch_id}, loss: {loss}') + + +def test_sl_cv_model(): + vgg = SimSLVGG16() + simulate_sl_model_training(vgg) + + resnet = SimSLResNet18() + simulate_sl_model_training(resnet)