Update the GNN encoding for XGB financial example #3039

Merged · 15 commits · Oct 14, 2024
21 changes: 13 additions & 8 deletions examples/advanced/finance-end-to-end/README.md
@@ -225,6 +225,8 @@ Please refer to the scripts:
- [graph_construct.py](./nvflare/graph_construct.py) and [graph_construct_job.py](./nvflare/graph_construct_job.py) for graph construction
- [gnn_train_encode.py](./nvflare/gnn_train_encode.py) and [gnn_train_encode_job.py](./nvflare/gnn_train_encode_job.py) for GNN training and encoding

The resulting GNN encodings will be merged with the normalized data to enhance the features.
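
For intuition, the merge joins each site's GNN embedding rows to its normalized feature rows on the transaction identifier `UETR`; a minimal pandas sketch (file names here are illustrative, the full per-site logic lives in [merge_feat.py](./utils/merge_feat.py)):

```
import pandas as pd

# Illustrative file names; the real script resolves them per bank site
normalized = pd.read_csv("train_normalized.csv")  # hand-crafted features
embeddings = pd.read_csv("train_embedding.csv")   # GNN-produced features

# Match rows by the unique end-to-end transaction reference (UETR)
combined = pd.merge(normalized, embeddings, on="UETR")
combined.to_csv("train_combined.csv", index=False)
```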

## Step 3: Federated Training of XGBoost
Now that we have enriched/encoded features, the last step is to run federated XGBoost over them.
Below is the XGBoost job code:
@@ -285,7 +287,7 @@ it anywhere in a real deployment.

Assuming you have already downloaded the credit card dataset and the creditcard.csv file is located in the current directory:

* prepare data
* Prepare data
```
python ./utils/prepare_data.py -i ./creditcard.csv -o /tmp/nvflare/xgb/credit_card
```
@@ -302,21 +304,21 @@ python ./utils/prepare_data.py -i ./creditcard.csv -o /tmp/nvflare/xgb/credit_card
> * 'XITXUS33_Bank_10'
> Total 10 banks
* enrich data
* Enrich data
```
cd nvflare
python enrich_job.py -c 'ZNZZAU3M_Bank_8' 'SHSHKHH1_Bank_2' 'FBSFCHZH_Bank_6' 'YMNYFRPP_Bank_5' 'WPUWDEFF_Bank_4' 'YXRXGB22_Bank_3' 'XITXUS33_Bank_10' 'YSYCESMM_Bank_7' 'ZHSZUS33_Bank_1' 'HCBHSGSG_Bank_9' -p enrich.py -a "-i /tmp/nvflare/xgb/credit_card/ -o /tmp/nvflare/xgb/credit_card/"
cd ..
```
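
(In these job commands, `-c` lists the participating client sites, `-p` selects the client-side script to run, and `-a` forwards that script's command-line arguments; the same pattern applies to the pre-processing, graph-construction, and GNN steps below.)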

* pre-process data
* Pre-process data
```
cd nvflare
python pre_process_job.py -c 'YSYCESMM_Bank_7' 'FBSFCHZH_Bank_6' 'YXRXGB22_Bank_3' 'XITXUS33_Bank_10' 'HCBHSGSG_Bank_9' 'YMNYFRPP_Bank_5' 'ZHSZUS33_Bank_1' 'ZNZZAU3M_Bank_8' 'SHSHKHH1_Bank_2' 'WPUWDEFF_Bank_4' -p pre_process.py -a "-i /tmp/nvflare/xgb/credit_card -o /tmp/nvflare/xgb/credit_card/"
cd ..
```

* construct graph
* Construct graph
```
cd nvflare
python graph_construct_job.py -c 'YSYCESMM_Bank_7' 'FBSFCHZH_Bank_6' 'YXRXGB22_Bank_3' 'XITXUS33_Bank_10' 'HCBHSGSG_Bank_9' 'YMNYFRPP_Bank_5' 'ZHSZUS33_Bank_1' 'ZNZZAU3M_Bank_8' 'SHSHKHH1_Bank_2' 'WPUWDEFF_Bank_4' -p graph_construct.py -a "-i /tmp/nvflare/xgb/credit_card -o /tmp/nvflare/xgb/credit_card/"
@@ -330,6 +332,10 @@ python gnn_train_encode_job.py -c 'YSYCESMM_Bank_7' 'FBSFCHZH_Bank_6' 'YXRXGB22_Bank_3' 'XITXUS33_Bank_10' 'HCBHSGSG_Bank_9' 'YMNYFRPP_Bank_5' 'ZHSZUS33_Bank_1' 'ZNZZAU3M_Bank_8' 'SHSHKHH1_Bank_2' 'WPUWDEFF_Bank_4' -p gnn_train_encode.py -a "-i /tmp/nvflare/xgb/credit_card -o /tmp/nvflare/xgb/credit_card/"
cd ..
```

* Add GNN embeddings to the normalized data
```
python3 utils/merge_feat.py
```
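
After the merge, each bank's site directory should contain combined train/test files alongside the normalized and embedding ones, for example (assuming the default data directory):
```
ls /tmp/nvflare/xgb/credit_card/ZHSZUS33_Bank_1/
# expect train_combined.csv and test_combined.csv next to
# train_normalized.csv / train_embedding.csv and their test counterparts
```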

* XGBoost Job

@@ -343,7 +349,7 @@ cd ..
Below is the output of the last round of training (starting round = 0):
```
...
[9] eval-auc:0.67596 train-auc:0.70582
[9] eval-auc:0.69383 train-auc:0.71165
```
For GNN embeddings, we run the following command
```
@@ -354,7 +360,6 @@ cd ..
Below is the output of the last round of training (starting round = 0):
```
...
[9] eval-auc:0.53788 train-auc:0.61659
[9] eval-auc:0.72318 train-auc:0.72241
```
For this example, the normalized data performs better than the GNN embeddings. This is expected as the GNN embeddings are produced with randomly generated transactional information, which adds noise to the data.

As shown, the GNN embeddings help to improve model performance by providing extra features beyond the hand-crafted ones.
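
To sanity-check what the embedding run trains on, one can inspect a combined file; the path below assumes the default data directory, and every column beyond the original eleven comes from the GNN encoder:
```
import pandas as pd

# Path assumes the default data directory used throughout this example
df = pd.read_csv("/tmp/nvflare/xgb/credit_card/ZHSZUS33_Bank_1/train_combined.csv")

# First the hand-crafted columns kept by merge_feat.py, then the embedding dimensions
print(df.shape)
print(df.columns.tolist())
```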
95 changes: 95 additions & 0 deletions examples/advanced/finance-end-to-end/utils/merge_feat.py
@@ -0,0 +1,95 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import os

import pandas as pd

# Data splits to process for each site
files = ["train", "test"]

# Map BIC code to bank name; "<BIC>_<bank>" is each site's data directory
bic_to_bank = {
    "ZHSZUS33": "Bank_1",
    "SHSHKHH1": "Bank_2",
    "YXRXGB22": "Bank_3",
    "WPUWDEFF": "Bank_4",
    "YMNYFRPP": "Bank_5",
    "FBSFCHZH": "Bank_6",
    "YSYCESMM": "Bank_7",
    "ZNZZAU3M": "Bank_8",
    "HCBHSGSG": "Bank_9",
    "XITXUS33": "Bank_10",
}

# Hand-crafted feature columns to keep from the normalized data
original_columns = [
    "UETR",
    "Timestamp",
    "Amount",
    "trans_volume",
    "total_amount",
    "average_amount",
    "hist_trans_volume",
    "hist_total_amount",
    "hist_average_amount",
    "x2_y1",
    "x3_y2",
]


def main():
    args = define_parser()
    root_path = args.input_dir
    original_feat_postfix = "_normalized.csv"
    embed_feat_postfix = "_embedding.csv"
    out_feat_postfix = "_combined.csv"

    for bic in bic_to_bank.keys():
        print("Processing BIC: ", bic)
        for file in files:
            original_feat_file = os.path.join(root_path, bic + "_" + bic_to_bank[bic], file + original_feat_postfix)
            embed_feat_file = os.path.join(root_path, bic + "_" + bic_to_bank[bic], file + embed_feat_postfix)
            out_feat_file = os.path.join(root_path, bic + "_" + bic_to_bank[bic], file + out_feat_postfix)

            # Load the original and embedding features
            original_feat = pd.read_csv(original_feat_file)
            embed_feat = pd.read_csv(embed_feat_file)

            # Select the columns of the original features
            original_feat = original_feat[original_columns]

            # Combine the features, matching the rows by "UETR"
            out_feat = pd.merge(original_feat, embed_feat, on="UETR")

            # Save the combined features
            out_feat.to_csv(out_feat_file, index=False)


def define_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input_dir",
        type=str,
        nargs="?",
        default="/tmp/nvflare/xgb/credit_card",
        help="input directory, defaults to '/tmp/nvflare/xgb/credit_card'",
    )

    return parser.parse_args()


if __name__ == "__main__":
    main()
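
A usage sketch based on the argparse definition above; `-i` defaults to `/tmp/nvflare/xgb/credit_card`, so passing it is only needed for a non-default data directory:
```
python3 utils/merge_feat.py -i /tmp/nvflare/xgb/credit_card
```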
70 changes: 67 additions & 3 deletions examples/advanced/finance-end-to-end/xgboost.ipynb
@@ -128,7 +128,9 @@
"Similarly, we define job scripts on the server side to trigger and coordinate running client-side scripts on each site: \n",
"\n",
"* [graph_construction_job.py](./nvflare/graph_construct_job.py)\n",
"* [gnn_train_encode_job.py](./nvflare/gnn_train_encode_job.py)"
"* [gnn_train_encode_job.py](./nvflare/gnn_train_encode_job.py)\n",
"\n",
"The resulting GNN encodings will be merged with the normalized data to enhance the features."
]
},
{
@@ -144,7 +146,8 @@
"\n",
"To specify the controller and executor, we need to define a Job. You can find the job construction in\n",
"\n",
"* [xgb_job.py](./nvflare/xgb_job.py). \n",
"* [xgb_job.py](./nvflare/xgb_job.py)\n",
"* [xgb_job_embed.py](./nvflare/xgb_job_embed.py)\n",
"\n",
"Below is the main part of the code:\n",
"\n",
@@ -338,12 +341,51 @@
"%cd .."
]
},
{
"cell_type": "markdown",
"id": "c201d73d-d0eb-4691-b4f6-f4b930168ef2",
"metadata": {},
"source": [
"### GNN Training and Encoding"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "337f64e8-072e-4f6f-8698-f513ecb47f47",
"metadata": {},
"outputs": [],
"source": [
"%cd nvflare\n",
"! python gnn_train_encode_job.py -c 'YSYCESMM_Bank_7' 'FBSFCHZH_Bank_6' 'YXRXGB22_Bank_3' 'XITXUS33_Bank_10' 'HCBHSGSG_Bank_9' 'YMNYFRPP_Bank_5' 'ZHSZUS33_Bank_1' 'ZNZZAU3M_Bank_8' 'SHSHKHH1_Bank_2' 'WPUWDEFF_Bank_4' -p gnn_train_encode.py -a \"-i /tmp/nvflare/xgb/credit_card -o /tmp/nvflare/xgb/credit_card/\"\n",
"%cd .."
]
},
{
"cell_type": "markdown",
"id": "fb6484fc-e226-4b1b-bc79-afbae4d2b918",
"metadata": {},
"source": [
"### GNN Encoding Merge"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3c8adaf-0644-4534-85ac-d83454b9ac8d",
"metadata": {},
"outputs": [],
"source": [
"! python3 ./utils/merge_feat.py"
]
},
{
"cell_type": "markdown",
"id": "aae5236b-0f40-4b91-9fc2-4f2836b52537",
"metadata": {},
"source": [
"### Run XGBoost Job"
"### Run XGBoost Job\n",
"#### Without GNN embeddings"
]
},
{
@@ -360,6 +402,28 @@
"%cd .."
]
},
{
"cell_type": "markdown",
"id": "89cc1faa-bc20-4c78-8bb2-19880e98723f",
"metadata": {},
"source": [
"#### With GNN embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8530b8f8-5877-4c86-b72d-e8adacbad35a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%cd nvflare\n",
"! python xgb_job_embed.py -c 'YSYCESMM_Bank_7' 'FBSFCHZH_Bank_6' 'YXRXGB22_Bank_3' 'XITXUS33_Bank_10' 'HCBHSGSG_Bank_9' 'YMNYFRPP_Bank_5' 'ZHSZUS33_Bank_1' 'ZNZZAU3M_Bank_8' 'SHSHKHH1_Bank_2' 'WPUWDEFF_Bank_4' -i /tmp/nvflare/xgb/credit_card -w /tmp/nvflare/workspace/xgb/credit_card_embed\n",
"%cd .."
]
},
{
"cell_type": "markdown",
"id": "50a9090a-d50a-46d7-bc8f-388717d18f96",