Skip to content

Commit

Permalink
Update the GNN encoding for XGB financial example (NVIDIA#3039)
Browse files Browse the repository at this point in the history
* Readme notebook polish and cleanup

* Reorganize folder structure and initial gnn

* Complete the graph generate step with edgemap output

* Format fix

* Format fix

* Add graph construction and training notebooks

* Add full gnn functionality

* Update wording for readme

* update the GNN embedding usage

---------

Co-authored-by: Chester Chen <[email protected]>
Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
  • Loading branch information
3 people committed Dec 10, 2024
1 parent 232823b commit bc46189
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 3 deletions.
95 changes: 95 additions & 0 deletions examples/advanced/finance-end-to-end/utils/merge_feat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import os

import pandas as pd

files = ["train", "test"]

bic_to_bank = {
"ZHSZUS33": "Bank_1",
"SHSHKHH1": "Bank_2",
"YXRXGB22": "Bank_3",
"WPUWDEFF": "Bank_4",
"YMNYFRPP": "Bank_5",
"FBSFCHZH": "Bank_6",
"YSYCESMM": "Bank_7",
"ZNZZAU3M": "Bank_8",
"HCBHSGSG": "Bank_9",
"XITXUS33": "Bank_10",
}

original_columns = [
"UETR",
"Timestamp",
"Amount",
"trans_volume",
"total_amount",
"average_amount",
"hist_trans_volume",
"hist_total_amount",
"hist_average_amount",
"x2_y1",
"x3_y2",
]


def main():
args = define_parser()
root_path = args.input_dir
original_feat_postfix = "_normalized.csv"
embed_feat_postfix = "_embedding.csv"
out_feat_postfix = "_combined.csv"

for bic in bic_to_bank.keys():
print("Processing BIC: ", bic)
for file in files:
original_feat_file = os.path.join(root_path, bic + "_" + bic_to_bank[bic], file + original_feat_postfix)
embed_feat_file = os.path.join(root_path, bic + "_" + bic_to_bank[bic], file + embed_feat_postfix)
out_feat_file = os.path.join(root_path, bic + "_" + bic_to_bank[bic], file + out_feat_postfix)

# Load the original and embedding features
original_feat = pd.read_csv(original_feat_file)
embed_feat = pd.read_csv(embed_feat_file)

# Select the columns of the original features
original_feat = original_feat[original_columns]

# Combine the features, matching the rows by "UETR"
out_feat = pd.merge(original_feat, embed_feat, on="UETR")

# Save the combined features
out_feat.to_csv(out_feat_file, index=False)


def define_parser():
parser = argparse.ArgumentParser()

parser.add_argument(
"-i",
"--input_dir",
type=str,
nargs="?",
default="/tmp/nvflare/xgb/credit_card",
help="output directory, default to '/tmp/nvflare/xgb/credit_card'",
)

return parser.parse_args()


if __name__ == "__main__":
main()
70 changes: 67 additions & 3 deletions examples/advanced/finance-end-to-end/xgboost.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@
"Similarily, we define job scripts on server-side to trigger and coordinate running client-side scripts on each site: \n",
"\n",
"* [graph_construction_job.py](./nvflare/graph_construct_job.py)\n",
"* [gnn_train_encode_job.py](./nvflare/gnn_train_encode_job.py)"
"* [gnn_train_encode_job.py](./nvflare/gnn_train_encode_job.py)\n",
"\n",
"The resulting GNN encodings will be merged with the normalized data for enhancing the feature."
]
},
{
Expand All @@ -144,7 +146,8 @@
"\n",
"To specify the controller and executor, we need to define a Job. You can find the job construction in\n",
"\n",
"* [xgb_job.py](./nvflare/xgb_job.py). \n",
"* [xgb_job.py](./nvflare/xgb_job.py)\n",
"* [xgb_job_embed.py](./nvflare/xgb_job_embed.py)\n",
"\n",
"Below is main part of the code\n",
"\n",
Expand Down Expand Up @@ -338,12 +341,51 @@
"%cd .."
]
},
{
"cell_type": "markdown",
"id": "c201d73d-d0eb-4691-b4f6-f4b930168ef2",
"metadata": {},
"source": [
"### GNN Training and Encoding"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "337f64e8-072e-4f6f-8698-f513ecb47f47",
"metadata": {},
"outputs": [],
"source": [
"%cd nvflare\n",
"! python gnn_train_encode_job.py -c 'YSYCESMM_Bank_7' 'FBSFCHZH_Bank_6' 'YXRXGB22_Bank_3' 'XITXUS33_Bank_10' 'HCBHSGSG_Bank_9' 'YMNYFRPP_Bank_5' 'ZHSZUS33_Bank_1' 'ZNZZAU3M_Bank_8' 'SHSHKHH1_Bank_2' 'WPUWDEFF_Bank_4' -p gnn_train_encode.py -a \"-i /tmp/nvflare/xgb/credit_card -o /tmp/nvflare/xgb/credit_card/\"\n",
"%cd .."
]
},
{
"cell_type": "markdown",
"id": "fb6484fc-e226-4b1b-bc79-afbae4d2b918",
"metadata": {},
"source": [
"### GNN Encoding Merge"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3c8adaf-0644-4534-85ac-d83454b9ac8d",
"metadata": {},
"outputs": [],
"source": [
"! python3 ./utils/merge_feat.py"
]
},
{
"cell_type": "markdown",
"id": "aae5236b-0f40-4b91-9fc2-4f2836b52537",
"metadata": {},
"source": [
"### Run XGBoost Job"
"### Run XGBoost Job\n",
"#### Without GNN embeddings"
]
},
{
Expand All @@ -360,6 +402,28 @@
"%cd .."
]
},
{
"cell_type": "markdown",
"id": "89cc1faa-bc20-4c78-8bb2-19880e98723f",
"metadata": {},
"source": [
"#### With GNN embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8530b8f8-5877-4c86-b72d-e8adacbad35a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%cd nvflare\n",
"! python xgb_job_embed.py -c 'YSYCESMM_Bank_7' 'FBSFCHZH_Bank_6' 'YXRXGB22_Bank_3' 'XITXUS33_Bank_10' 'HCBHSGSG_Bank_9' 'YMNYFRPP_Bank_5' 'ZHSZUS33_Bank_1' 'ZNZZAU3M_Bank_8' 'SHSHKHH1_Bank_2' 'WPUWDEFF_Bank_4' -i /tmp/nvflare/xgb/credit_card -w /tmp/nvflare/workspace/xgb/credit_card_embed\n",
"%cd .."
]
},
{
"cell_type": "markdown",
"id": "50a9090a-d50a-46d7-bc8f-388717d18f96",
Expand Down

0 comments on commit bc46189

Please sign in to comment.