diff --git a/finrl/meta/preprocessor/yahoodownloader.py b/finrl/meta/preprocessor/yahoodownloader.py index a337e947f..e4a2bd730 100644 --- a/finrl/meta/preprocessor/yahoodownloader.py +++ b/finrl/meta/preprocessor/yahoodownloader.py @@ -33,7 +33,7 @@ def __init__(self, start_date: str, end_date: str, ticker_list: list): self.end_date = end_date self.ticker_list = ticker_list - def fetch_data(self, proxy=None) -> pd.DataFrame: + def fetch_data(self, proxy=None, auto_adjust=False) -> pd.DataFrame: """Fetches data from Yahoo API Parameters ---------- @@ -49,7 +49,11 @@ def fetch_data(self, proxy=None) -> pd.DataFrame: num_failures = 0 for tic in self.ticker_list: temp_df = yf.download( - tic, start=self.start_date, end=self.end_date, proxy=proxy + tic, + start=self.start_date, + end=self.end_date, + proxy=proxy, + auto_adjust=auto_adjust, ) if temp_df.columns.nlevels != 1: temp_df.columns = temp_df.columns.droplevel(1) @@ -65,16 +69,20 @@ def fetch_data(self, proxy=None) -> pd.DataFrame: data_df = data_df.reset_index() try: # convert the column names to standardized names - data_df.columns = [ - "date", - "open", - "high", - "low", - "close", - "adjcp", - "volume", - "tic", - ] + data_df.rename( + columns={ + "Date": "date", + "Adj Close": "adjcp", + "Close": "close", + "High": "high", + "Low": "low", + "Volume": "volume", + "Open": "open", + "tic": "tic", + }, + inplace=True, + ) + # use adjusted close price instead of close price data_df["close"] = data_df["adjcp"] # drop the adjusted close price column