Add Step by step fed stats tabular example (#2151)
* WIP

* add step-by-step df_stats example

* fix grammars

* fix grammars
chesterxgchen authored Nov 17, 2023
1 parent a2c4c04 commit 64666c9
Showing 6 changed files with 1,152 additions and 4 deletions.
@@ -39,7 +39,7 @@
"\n",
"Follow [Getting Started](https://nvflare.readthedocs.io/en/main/getting_started.html) to set up a virtual environment and install NVFLARE.\n",
"\n",
"You can also follow this [notebook](../../nvflare_setup.ipynb) to get set up.\n",
"You can also follow this [notebook](https://github.com/NVIDIA/NVFlare/blob/main/examples/nvflare_setup.ipynb) to get set up.\n",
"\n",
"> Make sure you have installed nvflare from **terminal** \n"
]
@@ -421,7 +421,7 @@
"metadata": {},
"source": [
"## Create Federated Histogram Job\n",
"We are going to use NVFLARE job cli to create job. For detailed instructions on Job CLI, please follow the [job cli tutorial](https://github.com/NVIDIA/NVFlare/blob/eee330a23c2efd9a9eea37415259146d83f8d52f/examples/tutorials/job_cli.ipynb).\n",
"We are going to use NVFLARE job cli to create job. For detailed instructions on Job CLI, please follow the [job cli tutorial](https://github.com/NVIDIA/NVFlare/blob/main/examples/tutorials/job_cli.ipynb).\n",
"\n",
"Let's check the available job templates, we are going to use one of the existing job template and modify to fit our needs. \n",
"The job template is nothing but server and client-side job configurations.\n"
@@ -651,8 +651,9 @@
"The global and local histograms differences are none as we are using the same dataset for all clients. \n",
"\n",
"## We are done !\n",
"Congratulations, you just completed the federated stats image histogram calulation. \n",
"If you want to see another example of federated statistics calculations and configurations, please checkout [federated_statistics](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/federated-statistics) and [fed_stats with spleen_ct_segmentation](https://github.com/NVIDIA/NVFlare/tree/main/integration/monai/examples/spleen_ct_segmentation_sim)\n",
"Congratulations! you have just completed the federated stats image histogram calulation. \n",
"\n",
"If you would like to see another example of federated statistics calculations and configurations, please checkout [federated_statistics](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/federated-statistics) and [fed_stats with spleen_ct_segmentation](https://github.com/NVIDIA/NVFlare/tree/main/integration/monai/examples/spleen_ct_segmentation_sim)\n",
"\n",
"Let's move on to the next example and see how can we train the image classifier using pytorch with CIFAR10 data.\n",
"\n",
135 changes: 135 additions & 0 deletions examples/hello-world/step-by-step/higgs/stats/code/df_stats.py
@@ -0,0 +1,135 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
from pandas.core.series import Series

from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.statistics_spec import BinRange, Feature, Histogram, HistogramType, Statistics
from nvflare.app_common.statistics.numpy_utils import dtype_to_data_type, get_std_histogram_buckets


class DFStatistics(Statistics):
    def __init__(self, data_root_dir: str):
        super().__init__()
        self.data_root_dir = data_root_dir
        self.data: Optional[Dict[str, pd.DataFrame]] = None
        self.data_features = None

    def load_features(self, fl_ctx: FLContext) -> List:
        client_name = self.get_client_name(fl_ctx)
        try:
            data_path = f"{self.data_root_dir}/{client_name}_header.csv"
            with open(data_path, "r") as file:
                # the header file holds a single line of comma-separated feature names
                csv_reader = csv.reader(file)
                features = next(csv_reader)
            return features
        except Exception as e:
            raise Exception(f"Load header for client {client_name} failed! {e}") from e

    def load_data(self, fl_ctx: FLContext) -> Dict[str, pd.DataFrame]:
        client_name = self.get_client_name(fl_ctx)
        try:
            data_path = f"{self.data_root_dir}/{client_name}.csv"
            # example of loading data from a CSV file
            df: pd.DataFrame = pd.read_csv(
                data_path, names=self.data_features, sep=r"\s*,\s*", engine="python", na_values="?"
            )
            return {"train": df}

        except Exception as e:
            raise Exception(f"Load data for client {client_name} failed! {e}") from e

    def get_client_name(self, fl_ctx):
        # fall back to "site-1" so the generator can also be exercised outside a federated run
        client_name = fl_ctx.get_identity_name() if fl_ctx is not None else "site-1"
        if fl_ctx:
            self.log_info(fl_ctx, f"load data for client {client_name}")
        else:
            print(f"load data for client {client_name}")
        return client_name

    def initialize(self, fl_ctx: FLContext):
        self.data_features = self.load_features(fl_ctx)
        self.data = self.load_data(fl_ctx)
        if self.data is None:
            raise ValueError("data is not loaded. make sure the data is loaded")

    def features(self) -> Dict[str, List[Feature]]:
        results: Dict[str, List[Feature]] = {}
        for ds_name in self.data:
            df = self.data[ds_name]
            results[ds_name] = []
            for feature_name in df:
                data_type = dtype_to_data_type(df[feature_name].dtype)
                results[ds_name].append(Feature(feature_name, data_type))

        return results

    def count(self, dataset_name: str, feature_name: str) -> int:
        df: pd.DataFrame = self.data[dataset_name]
        return df[feature_name].count()

    def sum(self, dataset_name: str, feature_name: str) -> float:
        df: pd.DataFrame = self.data[dataset_name]
        return df[feature_name].sum().item()

    def mean(self, dataset_name: str, feature_name: str) -> float:
        count: int = self.count(dataset_name, feature_name)
        sum_value: float = self.sum(dataset_name, feature_name)
        return sum_value / count

    def stddev(self, dataset_name: str, feature_name: str) -> float:
        df = self.data[dataset_name]
        return df[feature_name].std().item()

    def variance_with_mean(
        self, dataset_name: str, feature_name: str, global_mean: float, global_count: float
    ) -> float:
        # sample variance around the global mean: sum((x - mu)^2) / (N - 1)
        df = self.data[dataset_name]
        tmp = (df[feature_name] - global_mean) * (df[feature_name] - global_mean)
        variance = tmp.sum() / (global_count - 1)
        return variance.item()

    def histogram(
        self, dataset_name: str, feature_name: str, num_of_bins: int, global_min_value: float, global_max_value: float
    ) -> Histogram:
        df = self.data[dataset_name]
        feature: Series = df[feature_name]
        flattened = feature.ravel()
        # drop missing values before binning
        flattened = flattened[flattened != np.array(None)]
        buckets = get_std_histogram_buckets(flattened, num_of_bins, BinRange(global_min_value, global_max_value))
        return Histogram(HistogramType.STANDARD, buckets)

    def max_value(self, dataset_name: str, feature_name: str) -> float:
        """This is needed for histogram calculation, not used for reporting."""
        df = self.data[dataset_name]
        return df[feature_name].max()

    def min_value(self, dataset_name: str, feature_name: str) -> float:
        """This is needed for histogram calculation, not used for reporting."""
        df = self.data[dataset_name]
        return df[feature_name].min()
@@ -0,0 +1,5 @@
nvflare>=2.3.0
numpy
pandas
matplotlib
jupyterlab
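A typical setup step for this requirements file, assuming it is run from the example directory that contains it:

pip install -r requirements.txt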
81 changes: 81 additions & 0 deletions examples/hello-world/step-by-step/higgs/stats/split_csv.py
@@ -0,0 +1,81 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import os
import shutil

import pandas as pd


def split_csv(input_file_path, output_dir, num_parts, part_name, sample_rate):
    # Read the CSV file into a pandas DataFrame
    df = pd.read_csv(input_file_path)

    # Calculate the number of rows per part
    total_size = int(len(df) * sample_rate)
    rows_per_part = total_size // num_parts

    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Split the DataFrame into N parts
    for i in range(num_parts):
        start_index = i * rows_per_part
        end_index = (i + 1) * rows_per_part if i < num_parts - 1 else total_size
        part_df = df.iloc[start_index:end_index]

        # Save each part to a separate CSV file
        output_file = os.path.join(output_dir, f"{part_name}{i + 1}.csv")
        part_df.to_csv(output_file, index=False)


def distribute_header_file(input_header_file: str, output_dir: str, num_parts: int, part_name: str):
    source_file = input_header_file

    # Give every site its own copy of the header file
    for i in range(num_parts):
        output_file = os.path.join(output_dir, f"{part_name}{i + 1}_header.csv")
        shutil.copy(source_file, output_file)
        print(f"File copied to {output_file}")


def define_args_parser():
    parser = argparse.ArgumentParser(description="csv data split")
    parser.add_argument("--input_data_path", type=str, help="input path to csv data file")
    parser.add_argument("--input_header_path", type=str, help="input path to csv header file")
    parser.add_argument("--site_num", type=int, help="Total number of sites or clients")
    parser.add_argument("--site_name_prefix", type=str, default="site-", help="Site name prefix")
    parser.add_argument("--output_dir", type=str, default="/tmp/nvflare/dataset/output", help="Output directory")
    parser.add_argument(
        "--sample_rate", type=float, default=1.0, help="fraction of the data to use; default 1.0 for 100%%"
    )
    return parser


def main():
    parser = define_args_parser()
    args = parser.parse_args()
    input_file = args.input_data_path
    output_directory = args.output_dir
    num_parts = args.site_num
    site_name_prefix = args.site_name_prefix
    sample_rate = args.sample_rate
    split_csv(input_file, output_directory, num_parts, site_name_prefix, sample_rate)
    distribute_header_file(args.input_header_path, output_directory, num_parts, site_name_prefix)


if __name__ == "__main__":
    main()
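A usage sketch for the splitter above, mapping directly to the arguments in define_args_parser; the HIGGS input paths are assumptions, and the first three arguments have no defaults so they must be supplied:

python split_csv.py \
    --input_data_path=/tmp/nvflare/dataset/HIGGS.csv \
    --input_header_path=/tmp/nvflare/dataset/headers.csv \
    --site_num=3 \
    --sample_rate=0.3

This writes site-1.csv through site-3.csv, each with a matching site-N_header.csv copy, into the default output directory /tmp/nvflare/dataset/output.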