-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Step by step fed stats tabular example (#2151)
* WIP * add step-by-step df_stats example * fix grammars * fix grammars
- Loading branch information
1 parent
a2c4c04
commit 64666c9
Showing
6 changed files
with
1,152 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
135 changes: 135 additions & 0 deletions
135
examples/hello-world/step-by-step/higgs/stats/code/df_stats.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import csv | ||
from typing import Dict, List, Optional | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from pandas.core.series import Series | ||
|
||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.app_common.abstract.statistics_spec import BinRange, Feature, Histogram, HistogramType, Statistics | ||
from nvflare.app_common.statistics.numpy_utils import dtype_to_data_type, get_std_histogram_buckets | ||
|
||
|
||
class DFStatistics(Statistics):
    """Local statistics generator backed by per-client CSV files.

    Column names are read from ``<data_root_dir>/<client_name>_header.csv`` and
    rows from ``<data_root_dir>/<client_name>.csv``. The loaded frame is exposed
    as a single dataset named ``"train"``.
    """

    def __init__(self, data_root_dir: str):
        """Remember where the data lives; nothing is read until ``initialize`` runs.

        Args:
            data_root_dir: directory holding the per-client data and header CSV files.
        """
        super().__init__()
        self.data_root_dir = data_root_dir
        # Dataset name -> DataFrame; populated by initialize().
        self.data: Optional[Dict[str, pd.DataFrame]] = None
        # Column names taken from the header file; populated by initialize().
        self.data_features: Optional[List[str]] = None

    def load_features(self, fl_ctx: FLContext) -> List[str]:
        """Read the column names from the client's header CSV.

        Args:
            fl_ctx: FL context used to resolve the client name (may be None).

        Returns:
            The fields of the first row of ``<client_name>_header.csv``.

        Raises:
            Exception: if the header file cannot be read or contains no row. The
                underlying error is chained as ``__cause__``.
        """
        client_name = self.get_client_name(fl_ctx)
        try:
            data_path = f"{self.data_root_dir}/{client_name}_header.csv"
            # newline="" is what the csv module asks for when it is handed a file object.
            with open(data_path, "r", newline="") as file:
                return next(csv.reader(file))
        except Exception as e:
            raise Exception(f"Load header for client {client_name} failed! {e}") from e

    def load_data(self, fl_ctx: FLContext) -> Dict[str, pd.DataFrame]:
        """Load the client's CSV into a DataFrame using the previously loaded column names.

        The file is parsed with a regex separator that swallows whitespace around
        commas, and ``"?"`` is treated as a missing value.

        Args:
            fl_ctx: FL context used to resolve the client name (may be None).

        Returns:
            A dict with the single key ``"train"`` mapped to the loaded DataFrame.

        Raises:
            Exception: if the data file cannot be read or parsed. The underlying
                error is chained as ``__cause__``.
        """
        client_name = self.get_client_name(fl_ctx)
        try:
            data_path = f"{self.data_root_dir}/{client_name}.csv"
            # names=...: the column labels come from the header file, not from the data file.
            df: pd.DataFrame = pd.read_csv(
                data_path, names=self.data_features, sep=r"\s*,\s*", engine="python", na_values="?"
            )
            return {"train": df}
        except Exception as e:
            raise Exception(f"Load data for client {client_name} failed! {e}") from e

    def get_client_name(self, fl_ctx: Optional[FLContext]) -> str:
        """Resolve the client name and report which client's data is being loaded.

        Args:
            fl_ctx: FL context; when None the name falls back to ``"site-1"`` and the
                message is printed instead of being sent through ``log_info``.

        Returns:
            The identity name from ``fl_ctx``, or ``"site-1"`` when no context is given.
        """
        # A single `is None` test decides both the name and the reporting channel.
        if fl_ctx is None:
            client_name = "site-1"
            print(f"load data for client {client_name}")
        else:
            client_name = fl_ctx.get_identity_name()
            self.log_info(fl_ctx, f"load data for client {client_name}")
        return client_name

    def initialize(self, fl_ctx: FLContext):
        """Load the header first (load_data needs the column names), then the data.

        Raises:
            ValueError: if ``load_data`` returned None.
        """
        self.data_features = self.load_features(fl_ctx)
        self.data = self.load_data(fl_ctx)
        if self.data is None:
            raise ValueError("data is not loaded. make sure the data is loaded")

    def features(self) -> Dict[str, List[Feature]]:
        """Describe every column of every loaded dataset.

        Returns:
            Dataset name -> list of ``Feature(column_name, data_type)``, where the data
            type is derived from the column dtype by ``dtype_to_data_type``.
        """
        return {
            ds_name: [Feature(feature_name, dtype_to_data_type(df[feature_name].dtype)) for feature_name in df]
            for ds_name, df in self.data.items()
        }

    def count(self, dataset_name: str, feature_name: str) -> int:
        """Return the number of non-missing values of a feature as a built-in int."""
        df: pd.DataFrame = self.data[dataset_name]
        # Series.count() may hand back a NumPy integer; convert so the result matches the annotation.
        return int(df[feature_name].count())

    def sum(self, dataset_name: str, feature_name: str) -> float:
        """Return the sum of a feature as a Python scalar."""
        df: pd.DataFrame = self.data[dataset_name]
        return df[feature_name].sum().item()

    def mean(self, dataset_name: str, feature_name: str) -> float:
        """Return sum / count for a feature.

        Returns:
            The mean, or NaN when the feature has no non-missing values.
        """
        count: int = self.count(dataset_name, feature_name)
        sum_value: float = self.sum(dataset_name, feature_name)
        if count == 0:
            # Nothing to average; avoid ZeroDivisionError and report "undefined".
            return float("nan")
        return sum_value / count

    def stddev(self, dataset_name: str, feature_name: str) -> float:
        """Return the standard deviation of a feature, computed with ``Series.std()`` defaults."""
        df = self.data[dataset_name]
        return df[feature_name].std().item()

    def variance_with_mean(
        self, dataset_name: str, feature_name: str, global_mean: float, global_count: float
    ) -> float:
        """Return this client's share of the variance around a given global mean.

        Args:
            dataset_name: dataset to read.
            feature_name: column to read.
            global_mean: mean to measure deviations from.
            global_count: count used for the ``global_count - 1`` denominator.

        Returns:
            Sum of squared deviations of the local values from ``global_mean``,
            divided by ``global_count - 1``.
        """
        df = self.data[dataset_name]
        deviation = df[feature_name] - global_mean
        variance = (deviation * deviation).sum() / (global_count - 1)
        return variance.item()

    def histogram(
        self, dataset_name: str, feature_name: str, num_of_bins: int, global_min_value: float, global_max_value: float
    ) -> Histogram:
        """Build a standard histogram of a feature over the given global range.

        Args:
            dataset_name: dataset to read.
            feature_name: column to read.
            num_of_bins: number of buckets.
            global_min_value: lower bound of the bin range.
            global_max_value: upper bound of the bin range.

        Returns:
            A ``HistogramType.STANDARD`` histogram with the buckets produced by
            ``get_std_histogram_buckets``.
        """
        df = self.data[dataset_name]
        feature: Series = df[feature_name]
        # Drop missing entries (None and NaN) and hand a plain 1-D NumPy array to the bucketing helper.
        values = feature.dropna().to_numpy()
        buckets = get_std_histogram_buckets(values, num_of_bins, BinRange(global_min_value, global_max_value))
        return Histogram(HistogramType.STANDARD, buckets)

    def max_value(self, dataset_name: str, feature_name: str) -> float:
        """Return the feature maximum; needed for histogram calculation, not used for reporting."""
        df = self.data[dataset_name]
        return df[feature_name].max()

    def min_value(self, dataset_name: str, feature_name: str) -> float:
        """Return the feature minimum; needed for histogram calculation, not used for reporting."""
        df = self.data[dataset_name]
        return df[feature_name].min()
5 changes: 5 additions & 0 deletions
5
examples/hello-world/step-by-step/higgs/stats/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
nvflare>=2.3.0 | ||
numpy | ||
pandas | ||
matplotlib | ||
jupyterlab |
81 changes: 81 additions & 0 deletions
81
examples/hello-world/step-by-step/higgs/stats/split_csv.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import argparse | ||
import os | ||
import shutil | ||
|
||
import pandas as pd | ||
|
||
|
||
def split_csv(input_file_path, output_dir, num_parts, part_name, sample_rate):
    """Split a header-less CSV file into ``num_parts`` contiguous part files.

    Only the first ``int(row_count * sample_rate)`` rows are used. They are cut into
    ``num_parts`` consecutive slices of equal size, with the last slice also taking
    the remainder. Part ``i`` (1-based) is written to
    ``<output_dir>/<part_name><i>.csv`` without a header line or index column, so
    every line of every output file is a data row.

    Args:
        input_file_path: path of the CSV data file (no header line).
        output_dir: directory for the part files; created if missing.
        num_parts: number of part files to produce.
        part_name: file-name prefix of each part.
        sample_rate: fraction of the input rows to use.

    Raises:
        ValueError: if ``num_parts`` is less than 1.
    """
    if num_parts < 1:
        raise ValueError(f"num_parts must be at least 1, got {num_parts}")

    # header=None: column names are shipped in a separate header file, so the first
    # line of the data file is a data row and must not be consumed as column labels.
    df = pd.read_csv(input_file_path, header=None)

    total_size = int(len(df) * sample_rate)
    rows_per_part = total_size // num_parts

    os.makedirs(output_dir, exist_ok=True)

    for i in range(num_parts):
        start_index = i * rows_per_part
        # The last part absorbs whatever the integer division left over.
        end_index = (i + 1) * rows_per_part if i < num_parts - 1 else total_size

        output_file = os.path.join(output_dir, f"{part_name}{i + 1}.csv")
        # header=False keeps the parts header-less, matching the input layout.
        df.iloc[start_index:end_index].to_csv(output_file, header=False, index=False)
|
||
|
||
def distribute_header_file(input_header_file: str, output_dir: str, num_parts: int, part_name: str):
    """Give every part its own copy of the header file.

    The header is copied to ``<output_dir>/<part_name><i>_header.csv`` for
    ``i = 1 .. num_parts``, and a confirmation line is printed per copy.

    Args:
        input_header_file: path of the header CSV to copy.
        output_dir: destination directory; created if missing.
        num_parts: number of copies to make.
        part_name: file-name prefix of each copy.
    """
    # Do not rely on another step having created the directory already.
    os.makedirs(output_dir, exist_ok=True)

    for i in range(num_parts):
        output_file = os.path.join(output_dir, f"{part_name}{i + 1}_header.csv")
        shutil.copy(input_header_file, output_file)
        print(f"File copied to {output_file}")
|
||
|
||
def define_args_parser():
    """Build the command-line parser for the CSV splitting script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--input_data_path``,
        ``--input_header_path``, ``--site_num``, ``--site_name_prefix``,
        ``--output_dir`` and ``--sample_rate``.
    """
    parser = argparse.ArgumentParser(description="csv data split")
    parser.add_argument("--input_data_path", type=str, help="input path to csv data file")
    parser.add_argument("--input_header_path", type=str, help="input path to csv header file")
    parser.add_argument("--site_num", type=int, help="Total number of sites or clients")
    parser.add_argument("--site_name_prefix", type=str, default="site-", help="Site name prefix")
    parser.add_argument("--output_dir", type=str, default="/tmp/nvflare/dataset/output", help="Output directory")
    # argparse %-formats help strings, so a literal percent sign has to be written as "%%";
    # a bare trailing "%" makes rendering the help text fail.
    parser.add_argument(
        "--sample_rate", type=float, default=1.0, help="percent of the data will be used. default 1.0 for 100%%"
    )
    return parser
|
||
|
||
def main():
    """Parse the command line, split the data file, then copy the header to every part."""
    args = define_args_parser().parse_args()

    # Both steps share the output directory, the part count and the file-name prefix.
    split_csv(args.input_data_path, args.output_dir, args.site_num, args.site_name_prefix, args.sample_rate)
    distribute_header_file(args.input_header_path, args.output_dir, args.site_num, args.site_name_prefix)


# Run the splitter only when this file is executed as a script.
if __name__ == "__main__":
    main()
Oops, something went wrong.