diff --git a/examples/advanced/finance-end-to-end/utils/merge_feat.py b/examples/advanced/finance-end-to-end/utils/merge_feat.py new file mode 100644 index 0000000000..a7eb2df11e --- /dev/null +++ b/examples/advanced/finance-end-to-end/utils/merge_feat.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +import pandas as pd + +files = ["train", "test"] + +bic_to_bank = { + "ZHSZUS33": "Bank_1", + "SHSHKHH1": "Bank_2", + "YXRXGB22": "Bank_3", + "WPUWDEFF": "Bank_4", + "YMNYFRPP": "Bank_5", + "FBSFCHZH": "Bank_6", + "YSYCESMM": "Bank_7", + "ZNZZAU3M": "Bank_8", + "HCBHSGSG": "Bank_9", + "XITXUS33": "Bank_10", +} + +original_columns = [ + "UETR", + "Timestamp", + "Amount", + "trans_volume", + "total_amount", + "average_amount", + "hist_trans_volume", + "hist_total_amount", + "hist_average_amount", + "x2_y1", + "x3_y2", +] + + +def main(): + args = define_parser() + root_path = args.input_dir + original_feat_postfix = "_normalized.csv" + embed_feat_postfix = "_embedding.csv" + out_feat_postfix = "_combined.csv" + + for bic in bic_to_bank.keys(): + print("Processing BIC: ", bic) + for file in files: + original_feat_file = os.path.join(root_path, bic + "_" + bic_to_bank[bic], file + original_feat_postfix) + embed_feat_file = os.path.join(root_path, bic + "_" + bic_to_bank[bic], file + embed_feat_postfix) + out_feat_file = os.path.join(root_path, bic + "_" + bic_to_bank[bic], file + out_feat_postfix) + + # Load the original and embedding features + original_feat = pd.read_csv(original_feat_file) + embed_feat = pd.read_csv(embed_feat_file) + + # Select the columns of the original features + original_feat = original_feat[original_columns] + + # Combine the features, matching the rows by "UETR" + out_feat = pd.merge(original_feat, embed_feat, on="UETR") + + # Save the combined features + out_feat.to_csv(out_feat_file, index=False) + + +def define_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input_dir", + type=str, + nargs="?", + default="/tmp/nvflare/xgb/credit_card", + help="output directory, default to '/tmp/nvflare/xgb/credit_card'", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/finance-end-to-end/xgboost.ipynb b/examples/advanced/finance-end-to-end/xgboost.ipynb index eca92a0d88..d740f44fb2 100644 --- a/examples/advanced/finance-end-to-end/xgboost.ipynb +++ b/examples/advanced/finance-end-to-end/xgboost.ipynb @@ -128,7 +128,9 @@ "Similarily, we define job scripts on server-side to trigger and coordinate running client-side scripts on each site: \n", "\n", "* [graph_construction_job.py](./nvflare/graph_construct_job.py)\n", - "* [gnn_train_encode_job.py](./nvflare/gnn_train_encode_job.py)" + "* [gnn_train_encode_job.py](./nvflare/gnn_train_encode_job.py)\n", + "\n", + "The resulting GNN encodings will be merged with the normalized data for enhancing the feature." ] }, { @@ -144,7 +146,8 @@ "\n", "To specify the controller and executor, we need to define a Job. You can find the job construction in\n", "\n", - "* [xgb_job.py](./nvflare/xgb_job.py). \n", + "* [xgb_job.py](./nvflare/xgb_job.py)\n", + "* [xgb_job_embed.py](./nvflare/xgb_job_embed.py)\n", "\n", "Below is main part of the code\n", "\n", @@ -338,12 +341,51 @@ "%cd .." ] }, + { + "cell_type": "markdown", + "id": "c201d73d-d0eb-4691-b4f6-f4b930168ef2", + "metadata": {}, + "source": [ + "### GNN Training and Encoding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "337f64e8-072e-4f6f-8698-f513ecb47f47", + "metadata": {}, + "outputs": [], + "source": [ + "%cd nvflare\n", + "! python gnn_train_encode_job.py -c 'YSYCESMM_Bank_7' 'FBSFCHZH_Bank_6' 'YXRXGB22_Bank_3' 'XITXUS33_Bank_10' 'HCBHSGSG_Bank_9' 'YMNYFRPP_Bank_5' 'ZHSZUS33_Bank_1' 'ZNZZAU3M_Bank_8' 'SHSHKHH1_Bank_2' 'WPUWDEFF_Bank_4' -p gnn_train_encode.py -a \"-i /tmp/nvflare/xgb/credit_card -o /tmp/nvflare/xgb/credit_card/\"\n", + "%cd .." + ] + }, + { + "cell_type": "markdown", + "id": "fb6484fc-e226-4b1b-bc79-afbae4d2b918", + "metadata": {}, + "source": [ + "### GNN Encoding Merge" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3c8adaf-0644-4534-85ac-d83454b9ac8d", + "metadata": {}, + "outputs": [], + "source": [ + "! python3 ./utils/merge_feat.py" + ] + }, { "cell_type": "markdown", "id": "aae5236b-0f40-4b91-9fc2-4f2836b52537", "metadata": {}, "source": [ - "### Run XGBoost Job" + "### Run XGBoost Job\n", + "#### Without GNN embeddings" ] }, { @@ -360,6 +402,28 @@ "%cd .." ] }, + { + "cell_type": "markdown", + "id": "89cc1faa-bc20-4c78-8bb2-19880e98723f", + "metadata": {}, + "source": [ + "#### With GNN embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8530b8f8-5877-4c86-b72d-e8adacbad35a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%cd nvflare\n", + "! python xgb_job_embed.py -c 'YSYCESMM_Bank_7' 'FBSFCHZH_Bank_6' 'YXRXGB22_Bank_3' 'XITXUS33_Bank_10' 'HCBHSGSG_Bank_9' 'YMNYFRPP_Bank_5' 'ZHSZUS33_Bank_1' 'ZNZZAU3M_Bank_8' 'SHSHKHH1_Bank_2' 'WPUWDEFF_Bank_4' -i /tmp/nvflare/xgb/credit_card -w /tmp/nvflare/workspace/xgb/credit_card_embed\n", + "%cd .." + ] + }, { "cell_type": "markdown", "id": "50a9090a-d50a-46d7-bc8f-388717d18f96",