From 585ba82f3143d1f38ee4b36e3938d5865d7b98a4 Mon Sep 17 00:00:00 2001 From: hyqus Date: Wed, 20 Jan 2021 18:14:20 -0500 Subject: [PATCH 1/2] Add FX Ticker --- config/config.py | 1275 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 1246 insertions(+), 29 deletions(-) diff --git a/config/config.py b/config/config.py index 28b85e030..e9b8dd5a1 100644 --- a/config/config.py +++ b/config/config.py @@ -1,29 +1,1246 @@ -import pathlib - -#import finrl - -import pandas as pd -import datetime -import os -#pd.options.display.max_rows = 10 -#pd.options.display.max_columns = 10 - - -#PACKAGE_ROOT = pathlib.Path(finrl.__file__).resolve().parent -#PACKAGE_ROOT = pathlib.Path().resolve().parent - -#TRAINED_MODEL_DIR = PACKAGE_ROOT / "trained_models" -#DATASET_DIR = PACKAGE_ROOT / "data" - -# data -#TRAINING_DATA_FILE = "data/ETF_SPY_2009_2020.csv" -TRAINING_DATA_FILE = "data/dow_30_2009_2020.csv" - -now = datetime.datetime.now() -TRAINED_MODEL_DIR = f"trained_models/{now}" -os.makedirs(TRAINED_MODEL_DIR) -TURBULENCE_DATA = "data/dow30_turbulence_index.csv" - -TESTING_DATA_FILE = "test.csv" - - +import pathlib + +# import finrl + +import pandas as pd +import datetime +import os + +# pd.options.display.max_rows = 10 +# pd.options.display.max_columns = 10 + + +# PACKAGE_ROOT = pathlib.Path(finrl.__file__).resolve().parent +# PACKAGE_ROOT = pathlib.Path().resolve().parent + +TRAINED_MODEL_DIR = f"trained_models" +# DATASET_DIR = PACKAGE_ROOT / "data" + +# data +# TRAINING_DATA_FILE = "data/ETF_SPY_2009_2020.csv" +# TURBULENCE_DATA = "data/dow30_turbulence_index.csv" +# TESTING_DATA_FILE = "test.csv" + +# now = datetime.datetime.now() +# TRAINED_MODEL_DIR = f"trained_models/{now}" +DATA_SAVE_DIR = f"datasets" +TRAINED_MODEL_DIR = f"trained_models" +TENSORBOARD_LOG_DIR = f"tensorboard_log" +RESULTS_DIR = f"results" +# os.makedirs(TRAINED_MODEL_DIR) + + +## time_fmt = '%Y-%m-%d' +START_DATE = "2000-01-01" +END_DATE = "2021-01-01" + +START_TRADE_DATE = "2019-01-01" + +## dataset default columns +DEFAULT_DATA_COLUMNS = ["date", "tic", "close"] + +## stockstats technical indicator column names +## check https://pypi.org/project/stockstats/ for different names +TECHNICAL_INDICATORS_LIST = ["macd","boll_ub","boll_lb","rsi_30", "cci_30", "dx_30","close_30_sma","close_60_sma"] + + +## Model Parameters +A2C_PARAMS = {"n_steps": 5, "ent_coef": 0.01, "learning_rate": 0.0007} +PPO_PARAMS = { + "n_steps": 2048, + "ent_coef": 0.01, + "learning_rate": 0.00025, + "batch_size": 64, +} +DDPG_PARAMS = {"batch_size": 128, "buffer_size": 50000, "learning_rate": 0.001} +TD3_PARAMS = {"batch_size": 100, "buffer_size": 1000000, "learning_rate": 0.001} +SAC_PARAMS = { + "batch_size": 64, + "buffer_size": 100000, + "learning_rate": 0.0001, + "learning_starts": 100, + "batch_size": 64, + "ent_coef": "auto_0.1", +} + +######################################################## +############## Stock Ticker Setup starts ############## +SINGLE_TICKER = ["AAPL"] + +# self defined +SRI_KEHATI_TICKER = [ + "AALI.JK", + "ADHI.JK", + "ASII.JK", + "BBCA.JK", + "BBNI.JK", + "BBRI.JK", + "BBTN.JK", + "BMRI.JK", + "BSDE.JK", + "INDF.JK", + "JPFA.JK", + "JSMR.JK", + "KLBF.JK", + "PGAS.JK", + "PJAA.JK", + "PPRO.JK", + "SIDO.JK", + "SMGR.JK", + "TINS.JK", + "TLKM.JK", + "UNTR.JK", + "UNVR.JK", + "WIKA.JK", + "WSKT.JK", + "WTON.JK" +] + +# check https://wrds-www.wharton.upenn.edu/ for U.S. index constituents +# Dow 30 constituents at 2019/01 +DOW_30_TICKER = [ + "AAPL", + "MSFT", + "JPM", + "V", + "RTX", + "PG", + "GS", + "NKE", + "DIS", + "AXP", + "HD", + "INTC", + "WMT", + "IBM", + "MRK", + "UNH", + "KO", + "CAT", + "TRV", + "JNJ", + "CVX", + "MCD", + "VZ", + "CSCO", + "XOM", + "BA", + "MMM", + "PFE", + "WBA", + "DD", +] + +# Nasdaq 100 constituents at 2019/01 +NAS_100_TICKER = [ + "AMGN", + "AAPL", + "AMAT", + "INTC", + "PCAR", + "PAYX", + "MSFT", + "ADBE", + "CSCO", + "XLNX", + "QCOM", + "COST", + "SBUX", + "FISV", + "CTXS", + "INTU", + "AMZN", + "EBAY", + "BIIB", + "CHKP", + "GILD", + "NLOK", + "CMCSA", + "FAST", + "ADSK", + "CTSH", + "NVDA", + "GOOGL", + "ISRG", + "VRTX", + "HSIC", + "BIDU", + "ATVI", + "ADP", + "ROST", + "ORLY", + "CERN", + "BKNG", + "MYL", + "MU", + "DLTR", + "ALXN", + "SIRI", + "MNST", + "AVGO", + "TXN", + "MDLZ", + "FB", + "ADI", + "WDC", + "REGN", + "LBTYK", + "VRSK", + "NFLX", + "TSLA", + "CHTR", + "MAR", + "ILMN", + "LRCX", + "EA", + "AAL", + "WBA", + "KHC", + "BMRN", + "JD", + "SWKS", + "INCY", + "PYPL", + "CDW", + "FOXA", + "MXIM", + "TMUS", + "EXPE", + "TCOM", + "ULTA", + "CSX", + "NTES", + "MCHP", + "CTAS", + "KLAC", + "HAS", + "JBHT", + "IDXX", + "WYNN", + "MELI", + "ALGN", + "CDNS", + "WDAY", + "SNPS", + "ASML", + "TTWO", + "PEP", + "NXPI", + "XEL", + "AMD", + "NTAP", + "VRSN", + "LULU", + "WLTW", + "UAL", +] + +# SP 500 constituents at 2019 +SP_500_TICKER = [ + "A", + "AAL", + "AAP", + "AAPL", + "ABBV", + "ABC", + "ABMD", + "ABT", + "ACN", + "ADBE", + "ADI", + "ADM", + "ADP", + "ADS", + "ADSK", + "AEE", + "AEP", + "AES", + "AFL", + "AGN", + "AIG", + "AIV", + "AIZ", + "AJG", + "AKAM", + "ALB", + "ALGN", + "ALK", + "ALL", + "ALLE", + "ALXN", + "AMAT", + "AMCR", + "AMD", + "AME", + "AMG", + "AMGN", + "AMP", + "AMT", + "AMZN", + "ANET", + "ANSS", + "ANTM", + "AON", + "AOS", + "APA", + "APD", + "APH", + "APTV", + "ARE", + "ARNC", + "ATO", + "ATVI", + "AVB", + "AVGO", + "AVY", + "AWK", + "AXP", + "AZO", + "BA", + "BAC", + "BAX", + "BBT", + "BBY", + "BDX", + "BEN", + "BF.B", + "BHGE", + "BIIB", + "BK", + "BKNG", + "BLK", + "BLL", + "BMY", + "BR", + "BRK.B", + "BSX", + "BWA", + "BXP", + "C", + "CAG", + "CAH", + "CAT", + "CB", + "CBOE", + "CBRE", + "CBS", + "CCI", + "CCL", + "CDNS", + "CE", + "CELG", + "CERN", + "CF", + "CFG", + "CHD", + "CHRW", + "CHTR", + "CI", + "CINF", + "CL", + "CLX", + "CMA", + "CMCSA", + "CME", + "CMG", + "CMI", + "CMS", + "CNC", + "CNP", + "COF", + "COG", + "COO", + "COP", + "COST", + "COTY", + "CPB", + "CPRI", + "CPRT", + "CRM", + "CSCO", + "CSX", + "CTAS", + "CTL", + "CTSH", + "CTVA", + "CTXS", + "CVS", + "CVX", + "CXO", + "D", + "DAL", + "DD", + "DE", + "DFS", + "DG", + "DGX", + "DHI", + "DHR", + "DIS", + "DISCK", + "DISH", + "DLR", + "DLTR", + "DOV", + "DOW", + "DRE", + "DRI", + "DTE", + "DUK", + "DVA", + "DVN", + "DXC", + "EA", + "EBAY", + "ECL", + "ED", + "EFX", + "EIX", + "EL", + "EMN", + "EMR", + "EOG", + "EQIX", + "EQR", + "ES", + "ESS", + "ETFC", + "ETN", + "ETR", + "EVRG", + "EW", + "EXC", + "EXPD", + "EXPE", + "EXR", + "F", + "FANG", + "FAST", + "FB", + "FBHS", + "FCX", + "FDX", + "FE", + "FFIV", + "FIS", + "FISV", + "FITB", + "FLIR", + "FLS", + "FLT", + "FMC", + "FOXA", + "FRC", + "FRT", + "FTI", + "FTNT", + "FTV", + "GD", + "GE", + "GILD", + "GIS", + "GL", + "GLW", + "GM", + "GOOG", + "GPC", + "GPN", + "GPS", + "GRMN", + "GS", + "GWW", + "HAL", + "HAS", + "HBAN", + "HBI", + "HCA", + "HCP", + "HD", + "HES", + "HFC", + "HIG", + "HII", + "HLT", + "HOG", + "HOLX", + "HON", + "HP", + "HPE", + "HPQ", + "HRB", + "HRL", + "HSIC", + "HST", + "HSY", + "HUM", + "IBM", + "ICE", + "IDXX", + "IEX", + "IFF", + "ILMN", + "INCY", + "INFO", + "INTC", + "INTU", + "IP", + "IPG", + "IPGP", + "IQV", + "IR", + "IRM", + "ISRG", + "IT", + "ITW", + "IVZ", + "JBHT", + "JCI", + "JEC", + "JEF", + "JKHY", + "JNJ", + "JNPR", + "JPM", + "JWN", + "K", + "KEY", + "KEYS", + "KHC", + "KIM", + "KLAC", + "KMB", + "KMI", + "KMX", + "KO", + "KR", + "KSS", + "KSU", + "L", + "LB", + "LDOS", + "LEG", + "LEN", + "LH", + "LHX", + "LIN", + "LKQ", + "LLY", + "LMT", + "LNC", + "LNT", + "LOW", + "LRCX", + "LUV", + "LW", + "LYB", + "M", + "MA", + "MAA", + "MAC", + "MAR", + "MAS", + "MCD", + "MCHP", + "MCK", + "MCO", + "MDLZ", + "MDT", + "MET", + "MGM", + "MHK", + "MKC", + "MKTX", + "MLM", + "MMC", + "MMM", + "MNST", + "MO", + "MOS", + "MPC", + "MRK", + "MRO", + "MS", + "MSCI", + "MSFT", + "MSI", + "MTB", + "MTD", + "MU", + "MXIM", + "MYL", + "NBL", + "NCLH", + "NDAQ", + "NEE", + "NEM", + "NFLX", + "NI", + "NKE", + "NKTR", + "NLSN", + "NOC", + "NOV", + "NRG", + "NSC", + "NTAP", + "NTRS", + "NUE", + "NVDA", + "NWL", + "NWS", + "O", + "OI", + "OKE", + "OMC", + "ORCL", + "ORLY", + "OXY", + "PAYX", + "PBCT", + "PCAR", + "PEG", + "PEP", + "PFE", + "PFG", + "PG", + "PGR", + "PH", + "PHM", + "PKG", + "PKI", + "PLD", + "PM", + "PNC", + "PNR", + "PNW", + "PPG", + "PPL", + "PRGO", + "PRU", + "PSA", + "PSX", + "PVH", + "PWR", + "PXD", + "PYPL", + "QCOM", + "QRVO", + "RCL", + "RE", + "REG", + "REGN", + "RF", + "RHI", + "RJF", + "RL", + "RMD", + "ROK", + "ROL", + "ROP", + "ROST", + "RSG", + "RTN", + "SBAC", + "SBUX", + "SCHW", + "SEE", + "SHW", + "SIVB", + "SJM", + "SLB", + "SLG", + "SNA", + "SNPS", + "SO", + "SPG", + "SPGI", + "SRE", + "STI", + "STT", + "STX", + "STZ", + "SWK", + "SWKS", + "SYF", + "SYK", + "SYMC", + "SYY", + "T", + "TAP", + "TDG", + "TEL", + "TFX", + "TGT", + "TIF", + "TJX", + "TMO", + "TMUS", + "TPR", + "TRIP", + "TROW", + "TRV", + "TSCO", + "TSN", + "TSS", + "TTWO", + "TWTR", + "TXN", + "TXT", + "UA", + "UAL", + "UDR", + "UHS", + "ULTA", + "UNH", + "UNM", + "UNP", + "UPS", + "URI", + "USB", + "UTX", + "V", + "VAR", + "VFC", + "VIAB", + "VLO", + "VMC", + "VNO", + "VRSK", + "VRSN", + "VRTX", + "VTR", + "VZ", + "WAB", + "WAT", + "WBA", + "WCG", + "WDC", + "WEC", + "WELL", + "WFC", + "WHR", + "WLTW", + "WM", + "WMB", + "WMT", + "WRK", + "WU", + "WY", + "WYNN", + "XEC", + "XEL", + "XLNX", + "XOM", + "XRAY", + "XRX", + "XYL", + "YUM", + "ZBH", + "ZION", + "ZTS", +] + +# Hang Seng Index constituents at 2019/01 +HSI_50_TICKER = [ + "0011.HK", + "0005.HK", + "0012.HK", + "0006.HK", + "0003.HK", + "0016.HK", + "0019.HK", + "0002.HK", + "0001.HK", + "0267.HK", + "0101.HK", + "0941.HK", + "0762.HK", + "0066.HK", + "0883.HK", + "2388.HK", + "0017.HK", + "0083.HK", + "0939.HK", + "0388.HK", + "0386.HK", + "3988.HK", + "2628.HK", + "1398.HK", + "2318.HK", + "3328.HK", + "0688.HK", + "0857.HK", + "1088.HK", + "0700.HK", + "0836.HK", + "1109.HK", + "1044.HK", + "1299.HK", + "0151.HK", + "1928.HK", + "0027.HK", + "2319.HK", + "0823.HK", + "1113.HK", + "1038.HK", + "2018.HK", + "0175.HK", + "0288.HK", + "1997.HK", + "2007.HK", + "2382.HK", + "1093.HK", + "1177.HK", + "2313.HK", +] + +# www.csindex.com.cn, for SSE and CSI adjustments +# SSE 50 Index constituents at 2019 +SSE_50_TICKER = [ + "600000.SS", + "600036.SS", + "600104.SS", + "600030.SS", + "601628.SS", + "601166.SS", + "601318.SS", + "601328.SS", + "601088.SS", + "601857.SS", + "601601.SS", + "601668.SS", + "601288.SS", + "601818.SS", + "601989.SS", + "601398.SS", + "600048.SS", + "600028.SS", + "600050.SS", + "600519.SS", + "600016.SS", + "600887.SS", + "601688.SS", + "601186.SS", + "601988.SS", + "601211.SS", + "601336.SS", + "600309.SS", + "603993.SS", + "600690.SS", + "600276.SS", + "600703.SS", + "600585.SS", + "603259.SS", + "601888.SS", + "601138.SS", + "600196.SS", + "601766.SS", + "600340.SS", + "601390.SS", + "601939.SS", + "601111.SS", + "600029.SS", + "600019.SS", + "601229.SS", + "601800.SS", + "600547.SS", + "601006.SS", + "601360.SS", + "600606.SS", + "601319.SS", + "600837.SS", + "600031.SS", + "601066.SS", + "600009.SS", + "601236.SS", + "601012.SS", + "600745.SS", + "600588.SS", + "601658.SS", + "601816.SS", + "603160.SS", +] + +# CSI 300 Index constituents at 2019 +CSI_300_TICKER = [ + "600000.SS", + "600004.SS", + "600009.SS", + "600010.SS", + "600011.SS", + "600015.SS", + "600016.SS", + "600018.SS", + "600019.SS", + "600025.SS", + "600027.SS", + "600028.SS", + "600029.SS", + "600030.SS", + "600031.SS", + "600036.SS", + "600038.SS", + "600048.SS", + "600050.SS", + "600061.SS", + "600066.SS", + "600068.SS", + "600085.SS", + "600089.SS", + "600104.SS", + "600109.SS", + "600111.SS", + "600115.SS", + "600118.SS", + "600170.SS", + "600176.SS", + "600177.SS", + "600183.SS", + "600188.SS", + "600196.SS", + "600208.SS", + "600219.SS", + "600221.SS", + "600233.SS", + "600271.SS", + "600276.SS", + "600297.SS", + "600299.SS", + "600309.SS", + "600332.SS", + "600340.SS", + "600346.SS", + "600352.SS", + "600362.SS", + "600369.SS", + "600372.SS", + "600383.SS", + "600390.SS", + "600398.SS", + "600406.SS", + "600436.SS", + "600438.SS", + "600482.SS", + "600487.SS", + "600489.SS", + "600498.SS", + "600516.SS", + "600519.SS", + "600522.SS", + "600547.SS", + "600570.SS", + "600583.SS", + "600585.SS", + "600588.SS", + "600606.SS", + "600637.SS", + "600655.SS", + "600660.SS", + "600674.SS", + "600690.SS", + "600703.SS", + "600705.SS", + "600741.SS", + "600745.SS", + "600760.SS", + "600795.SS", + "600809.SS", + "600837.SS", + "600848.SS", + "600867.SS", + "600886.SS", + "600887.SS", + "600893.SS", + "600900.SS", + "600919.SS", + "600926.SS", + "600928.SS", + "600958.SS", + "600968.SS", + "600977.SS", + "600989.SS", + "600998.SS", + "600999.SS", + "601006.SS", + "601009.SS", + "601012.SS", + "601018.SS", + "601021.SS", + "601066.SS", + "601077.SS", + "601088.SS", + "601100.SS", + "601108.SS", + "601111.SS", + "601117.SS", + "601138.SS", + "601155.SS", + "601162.SS", + "601166.SS", + "601169.SS", + "601186.SS", + "601198.SS", + "601211.SS", + "601212.SS", + "601216.SS", + "601225.SS", + "601229.SS", + "601231.SS", + "601236.SS", + "601238.SS", + "601288.SS", + "601298.SS", + "601318.SS", + "601319.SS", + "601328.SS", + "601336.SS", + "601360.SS", + "601377.SS", + "601390.SS", + "601398.SS", + "601555.SS", + "601577.SS", + "601600.SS", + "601601.SS", + "601607.SS", + "601618.SS", + "601628.SS", + "601633.SS", + "601658.SS", + "601668.SS", + "601669.SS", + "601688.SS", + "601698.SS", + "601727.SS", + "601766.SS", + "601788.SS", + "601800.SS", + "601808.SS", + "601816.SS", + "601818.SS", + "601828.SS", + "601838.SS", + "601857.SS", + "601877.SS", + "601878.SS", + "601881.SS", + "601888.SS", + "601898.SS", + "601899.SS", + "601901.SS", + "601916.SS", + "601919.SS", + "601933.SS", + "601939.SS", + "601985.SS", + "601988.SS", + "601989.SS", + "601992.SS", + "601997.SS", + "601998.SS", + "603019.SS", + "603156.SS", + "603160.SS", + "603259.SS", + "603260.SS", + "603288.SS", + "603369.SS", + "603501.SS", + "603658.SS", + "603799.SS", + "603833.SS", + "603899.SS", + "603986.SS", + "603993.SS", + "000001.SZ", + "000002.SZ", + "000063.SZ", + "000066.SZ", + "000069.SZ", + "000100.SZ", + "000157.SZ", + "000166.SZ", + "000333.SZ", + "000338.SZ", + "000425.SZ", + "000538.SZ", + "000568.SZ", + "000596.SZ", + "000625.SZ", + "000627.SZ", + "000651.SZ", + "000656.SZ", + "000661.SZ", + "000671.SZ", + "000703.SZ", + "000708.SZ", + "000709.SZ", + "000723.SZ", + "000725.SZ", + "000728.SZ", + "000768.SZ", + "000776.SZ", + "000783.SZ", + "000786.SZ", + "000858.SZ", + "000860.SZ", + "000876.SZ", + "000895.SZ", + "000938.SZ", + "000961.SZ", + "000963.SZ", + "000977.SZ", + "001979.SZ", + "002001.SZ", + "002007.SZ", + "002008.SZ", + "002024.SZ", + "002027.SZ", + "002032.SZ", + "002044.SZ", + "002050.SZ", + "002120.SZ", + "002129.SZ", + "002142.SZ", + "002146.SZ", + "002153.SZ", + "002157.SZ", + "002179.SZ", + "002202.SZ", + "002230.SZ", + "002236.SZ", + "002241.SZ", + "002252.SZ", + "002271.SZ", + "002304.SZ", + "002311.SZ", + "002352.SZ", + "002371.SZ", + "002410.SZ", + "002415.SZ", + "002422.SZ", + "002456.SZ", + "002460.SZ", + "002463.SZ", + "002466.SZ", + "002468.SZ", + "002475.SZ", + "002493.SZ", + "002508.SZ", + "002555.SZ", + "002558.SZ", + "002594.SZ", + "002601.SZ", + "002602.SZ", + "002607.SZ", + "002624.SZ", + "002673.SZ", + "002714.SZ", + "002736.SZ", + "002739.SZ", + "002773.SZ", + "002841.SZ", + "002916.SZ", + "002938.SZ", + "002939.SZ", + "002945.SZ", + "002958.SZ", + "003816.SZ", + "300003.SZ", + "300014.SZ", + "300015.SZ", + "300033.SZ", + "300059.SZ", + "300122.SZ", + "300124.SZ", + "300136.SZ", + "300142.SZ", + "300144.SZ", + "300347.SZ", + "300408.SZ", + "300413.SZ", + "300433.SZ", + "300498.SZ", + "300601.SZ", + "300628.SZ", +] + +############## Stock Ticker Setup ends ############## + +###Jan 20,2020, added by YuQing Huang################### +################FX Ticker Setup Start################### +FX_TICKER = ["AUDCAD", + "AUDCHF", + "AUDJPY", + "AUDNZD", + "AUDSGD", + "AUDUSD", + "AUDUSD", + "AUDUSD", + "AUDUSD", + "AUDUSD", + "AUDUSD", + "AUDUSD", + "CADCHF", + "CADHKD", + "CADJPY", + "CHFJPY", + "CHFSGD", + "EURAUD", + "EURCAD", + "EURCHF", + "EURCHF", + "EURCHF", + "EURCZK", + "EURGBP", + "EURHKD", + "EURHUF", + "EURJPY", + "EURNOK", + "EURNZD", + "EURPLN", + "EURRUB", + "EURSEK", + "EURSGD", + "EURTRY", + "EURTRY", + "EURUSD", + "GBPAUD", + "GBPAUD", + "GBPAUD", + "GBPCAD", + "GBPCHF", + "GBPJPY", + "GBPNZD", + "GBPUSD", + "HKDJPY", + "NZDCAD", + "NZDCHF", + "NZDJPY", + "NZDUSD", + "SGDJPY", + "TRYJPY", + "USDCAD", + "USDCHF", + "USDCNH", + "USDCZK", + "USDHKD", + "USDHUF", + "USDILS", + "USDJPY", + "USDMXN", + "USDNOK", + "USDPLN", + "USDRON", + "USDRUB", + "USDSEK", + "USDSGD", + "USDTHB", + "USDTRY", + "USDZAR", + "XAGUSD", + "XAUUSD", + "ZARJPY", + "EURDKK" +] +################FX Ticker Setup End################### From 319eeed3c2608a24f367fc95d36e0099a41725c6 Mon Sep 17 00:00:00 2001 From: hyqus Date: Wed, 20 Jan 2021 19:07:48 -0500 Subject: [PATCH 2/2] actions are all positive. For sell, it should be negative to avoid mixing with buy actions are all positive. For sell, it should be negative to avoid mixing with buy --- env/env_stocktrading.py | 367 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 367 insertions(+) create mode 100644 env/env_stocktrading.py diff --git a/env/env_stocktrading.py b/env/env_stocktrading.py new file mode 100644 index 000000000..791afbc47 --- /dev/null +++ b/env/env_stocktrading.py @@ -0,0 +1,367 @@ +import os +import numpy as np +import pandas as pd +from gym.utils import seeding +import gym +from gym import spaces +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import pickle +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv +from stable_baselines3.common import logger + + +class StockTradingEnv(gym.Env): + """A stock trading environment for OpenAI gym""" + metadata = {'render.modes': ['human']} + + def __init__(self, + df, + stock_dim, + hmax, + initial_amount, + buy_cost_pct, + sell_cost_pct, + reward_scaling, + state_space, + action_space, + tech_indicator_list, + turbulence_threshold=None, + make_plots = False, + print_verbosity = 10, + day = 0, + initial=True, + previous_state=[], + model_name = '', + mode='', + iteration=''): + self.day = day + self.df = df + self.stock_dim = stock_dim + self.hmax = hmax + self.initial_amount = initial_amount + self.buy_cost_pct = buy_cost_pct + self.sell_cost_pct = sell_cost_pct + self.reward_scaling = reward_scaling + self.state_space = state_space + self.action_space = action_space + self.tech_indicator_list = tech_indicator_list + self.action_space = spaces.Box(low = -1, high = 1,shape = (self.action_space,)) + self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape = (self.state_space,)) + self.data = self.df.loc[self.day,:] + self.terminal = False + self.make_plots = make_plots + self.print_verbosity = print_verbosity + self.turbulence_threshold = turbulence_threshold + self.initial = initial + self.previous_state = previous_state + self.model_name=model_name + self.mode=mode + self.iteration=iteration + # initalize state + self.state = self._initiate_state() + + # initialize reward + self.reward = 0 + self.turbulence = 0 + self.cost = 0 + self.trades = 0 + self.episode = 0 + # memorize all the total balance change + self.asset_memory = [self.initial_amount] + self.rewards_memory = [] + self.actions_memory=[] + self.date_memory=[self._get_date()] + #self.reset() + self._seed() + + + + def _sell_stock(self, index, action): + def _do_sell_normal(): + # perform sell action based on the sign of the action + if self.state[index+self.stock_dim+1] > 0: + sell_num_shares = min(abs(action),self.state[index+self.stock_dim+1]) + sell_amount = self.state[index+1]* sell_num_shares * (1- self.sell_cost_pct) + #update balance + self.state[0] += sell_amount + + self.state[index+self.stock_dim+1] -= min(abs(action), self.state[index+self.stock_dim+1]) + self.cost +=self.state[index+1]*min(abs(action),self.state[index+self.stock_dim+1]) * \ + self.sell_cost_pct + self.trades+=1 + else: + sell_num_shares = 0 + pass + return sell_num_shares + + # perform sell action based on the sign of the action + if self.turbulence_threshold is not None: + if self.turbulence>=self.turbulence_threshold: + # if turbulence goes over threshold, just clear out all positions + if self.state[index+self.stock_dim+1] > 0: + sell_num_shares = self.state[index+self.stock_dim+1] + sell_amount = self.state[index+1]*sell_num_shares* (1- self.sell_cost_pct) + #update balance + self.state[0] += sell_amount + self.state[index+self.stock_dim+1] =0 + self.cost += self.state[index+1]*self.state[index+self.stock_dim+1]* \ + self.sell_cost_pct + self.trades+=1 + else: + sell_num_shares = 0 + pass + else: + sell_num_shares = _do_sell_normal() + else: + sell_num_shares = _do_sell_normal() + + return sell_num_shares + + + def _buy_stock(self, index, action): + + def _do_buy(): + available_amount = self.state[0] // self.state[index+1] + # print('available_amount:{}'.format(available_amount)) + + #update balance + buy_num_shares = min(available_amount, action) + buy_amount = self.state[index+1]* buy_num_shares * (1+ self.buy_cost_pct) + self.state[0] -= buy_amount + + self.state[index+self.stock_dim+1] += min(available_amount, action) + + self.cost+=self.state[index+1]*min(available_amount, action)* \ + self.buy_cost_pct + self.trades+=1 + + return buy_num_shares + # perform buy action based on the sign of the action + if self.turbulence_threshold is None: + buy_num_shares = _do_buy() + else: + if self.turbulence< self.turbulence_threshold: + buy_num_shares = _do_buy() + else: + pass + + return buy_num_shares + + def _make_plot(self): + plt.plot(self.asset_memory,'r') + plt.savefig('results/account_value_trade_{}.png'.format(self.episode)) + plt.close() + + def step(self, actions): + self.terminal = self.day >= len(self.df.index.unique())-1 + if self.terminal: + # print(f"Episode: {self.episode}") + if self.make_plots: + self._make_plot() + end_total_asset = self.state[0]+ \ + sum(np.array(self.state[1:(self.stock_dim+1)])*np.array(self.state[(self.stock_dim+1):(self.stock_dim*2+1)])) + df_total_value = pd.DataFrame(self.asset_memory) + tot_reward = self.state[0]+sum(np.array(self.state[1:(self.stock_dim+1)])*np.array(self.state[(self.stock_dim+1):(self.stock_dim*2+1)]))- self.initial_amount + df_total_value.columns = ['account_value'] + df_total_value['date'] = self.date_memory + df_total_value['daily_return']=df_total_value['account_value'].pct_change(1) + if df_total_value['daily_return'].std() !=0: + sharpe = (252**0.5)*df_total_value['daily_return'].mean()/ \ + df_total_value['daily_return'].std() + df_rewards = pd.DataFrame(self.rewards_memory) + df_rewards.columns = ['account_rewards'] + df_rewards['date'] = self.date_memory[:-1] + if self.episode % self.print_verbosity == 0: + print(f"day: {self.day}, episode: {self.episode}") + print(f"begin_total_asset: {self.asset_memory[0]:0.2f}") + print(f"end_total_asset: {end_total_asset:0.2f}") + print(f"total_reward: {tot_reward:0.2f}") + print(f"total_cost: {self.cost:0.2f}") + print(f"total_trades: {self.trades}") + if df_total_value['daily_return'].std() != 0: + print(f"Sharpe: {sharpe:0.3f}") + print("=================================") + + if (self.model_name!='') and (self.mode!=''): + df_actions = self.save_action_memory() + df_actions.to_csv('results/actions_{}_{}_{}.csv'.format(self.mode,self.model_name, self.iteration)) + df_total_value.to_csv('results/account_value_{}_{}_{}.csv'.format(self.mode,self.model_name, self.iteration),index=False) + df_rewards.to_csv('results/account_rewards_{}_{}_{}.csv'.format(self.mode,self.model_name, self.iteration),index=False) + plt.plot(self.asset_memory,'r') + plt.savefig('results/account_value_{}_{}_{}.png'.format(self.mode,self.model_name, self.iteration),index=False) + plt.close() + + # Add outputs to logger interface + logger.record("environment/portfolio_value", end_total_asset) + logger.record("environment/total_reward", tot_reward) + logger.record("environment/total_reward_pct", (tot_reward / (end_total_asset - tot_reward)) * 100) + logger.record("environment/total_cost", self.cost) + logger.record("environment/total_trades", self.trades) + + return self.state, self.reward, self.terminal, {} + + else: + + actions = actions * self.hmax #actions initially is scaled between 0 to 1 + actions = (actions.astype(int)) #convert into integer because we can't by fraction of shares + if self.turbulence_threshold is not None: + if self.turbulence>=self.turbulence_threshold: + actions=np.array([-self.hmax]*self.stock_dim) + begin_total_asset = self.state[0]+ \ + sum(np.array(self.state[1:(self.stock_dim+1)])*np.array(self.state[(self.stock_dim+1):(self.stock_dim*2+1)])) + #print("begin_total_asset:{}".format(begin_total_asset)) + + argsort_actions = np.argsort(actions) + + sell_index = argsort_actions[:np.where(actions < 0)[0].shape[0]] + buy_index = argsort_actions[::-1][:np.where(actions > 0)[0].shape[0]] + + for index in sell_index: + # print(f"Num shares before: {self.state[index+self.stock_dim+1]}") + # print(f'take sell action before : {actions[index]}') + actions[index] = self._sell_stock(index, actions[index])*(-1) + # print(f'take sell action after : {actions[index]}') + # print(f"Num shares after: {self.state[index+self.stock_dim+1]}") + + for index in buy_index: + # print('take buy action: {}'.format(actions[index])) + actions[index] = self._buy_stock(index, actions[index]) + + self.actions_memory.append(actions) + + self.day += 1 + self.data = self.df.loc[self.day,:] + if self.turbulence_threshold is not None: + self.turbulence = self.data['turbulence'].values[0] + self.state = self._update_state() + + end_total_asset = self.state[0]+ \ + sum(np.array(self.state[1:(self.stock_dim+1)])*np.array(self.state[(self.stock_dim+1):(self.stock_dim*2+1)])) + self.asset_memory.append(end_total_asset) + self.date_memory.append(self._get_date()) + self.reward = end_total_asset - begin_total_asset + self.rewards_memory.append(self.reward) + self.reward = self.reward*self.reward_scaling + + return self.state, self.reward, self.terminal, {} + + def reset(self): + if self.initial: + self.asset_memory = [self.initial_amount] + else: + previous_total_asset = self.previous_state[0]+ \ + sum(np.array(self.previous_state[1:(self.stock_dim+1)])*np.array(self.previous_state[(self.stock_dim+1):(self.stock_dim*2+1)])) + self.asset_memory = [previous_total_asset] + + self.day = 0 + self.data = self.df.loc[self.day,:] + self.turbulence = 0 + self.cost = 0 + self.trades = 0 + self.terminal = False + # self.iteration=self.iteration + self.rewards_memory = [] + self.actions_memory=[] + self.date_memory=[self._get_date()] + #initiate state + self.state = self._initiate_state() + self.episode+=1 + + return self.state + + def render(self, mode='human',close=False): + return self.state + + def _initiate_state(self): + if self.initial: + # For Initial State + if len(self.df.tic.unique())>1: + # for multiple stock + state = [self.initial_amount] + \ + self.data.close.values.tolist() + \ + [0]*self.stock_dim + \ + sum([self.data[tech].values.tolist() for tech in self.tech_indicator_list ], []) + else: + # for single stock + state = [self.initial_amount] + \ + [self.data.close] + \ + [0]*self.stock_dim + \ + sum([[self.data[tech]] for tech in self.tech_indicator_list ], []) + else: + #Using Previous State + if len(self.df.tic.unique())>1: + # for multiple stock + state = [self.previous_state[0]] + \ + self.data.close.values.tolist() + \ + self.previous_state[(self.stock_dim+1):(self.stock_dim*2+1)] + \ + sum([self.data[tech].values.tolist() for tech in self.tech_indicator_list ], []) + else: + # for single stock + state = [self.previous_state[0]] + \ + [self.data.close] + \ + self.previous_state[(self.stock_dim+1):(self.stock_dim*2+1)] + \ + sum([[self.data[tech]] for tech in self.tech_indicator_list ], []) + return state + + def _update_state(self): + if len(self.df.tic.unique())>1: + # for multiple stock + state = [self.state[0]] + \ + self.data.close.values.tolist() + \ + list(self.state[(self.stock_dim+1):(self.stock_dim*2+1)]) + \ + sum([self.data[tech].values.tolist() for tech in self.tech_indicator_list ], []) + + else: + # for single stock + state = [self.state[0]] + \ + [self.data.close] + \ + list(self.state[(self.stock_dim+1):(self.stock_dim*2+1)]) + \ + sum([[self.data[tech]] for tech in self.tech_indicator_list ], []) + + return state + + def _get_date(self): + if len(self.df.tic.unique())>1: + date = self.data.date.unique()[0] + else: + date = self.data.date + return date + + def save_asset_memory(self): + date_list = self.date_memory + asset_list = self.asset_memory + #print(len(date_list)) + #print(len(asset_list)) + df_account_value = pd.DataFrame({'date':date_list,'account_value':asset_list}) + return df_account_value + + def save_action_memory(self): + if len(self.df.tic.unique())>1: + # date and close price length must match actions length + date_list = self.date_memory[:-1] + df_date = pd.DataFrame(date_list) + df_date.columns = ['date'] + + action_list = self.actions_memory + df_actions = pd.DataFrame(action_list) + df_actions.columns = self.data.tic.values + df_actions.index = df_date.date + #df_actions = pd.DataFrame({'date':date_list,'actions':action_list}) + else: + date_list = self.date_memory[:-1] + action_list = self.actions_memory + df_actions = pd.DataFrame({'date':date_list,'actions':action_list}) + return df_actions + + def _seed(self, seed=None): + self.np_random, seed = seeding.np_random(seed) + return [seed] + + + def get_sb_env(self): + e = DummyVecEnv([lambda: self]) + obs = e.reset() + return e, obs +