From 057b675e3c8e987665069938be07c0a1120e7ecd Mon Sep 17 00:00:00 2001 From: Brett Elliot Date: Thu, 7 Nov 2024 20:30:01 -0500 Subject: [PATCH] use a trading calendar to get the start date that is length bars earlier then end_date --- lumibot/data_sources/tradier_data.py | 10 +++++++++- tests/test_bars.py | 8 ++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/lumibot/data_sources/tradier_data.py b/lumibot/data_sources/tradier_data.py index 9cfb7c624..9ff9db89e 100644 --- a/lumibot/data_sources/tradier_data.py +++ b/lumibot/data_sources/tradier_data.py @@ -6,7 +6,7 @@ import pytz from lumibot.entities import Asset, Bars -from lumibot.tools.helpers import create_options_symbol, parse_timestep_qty_and_unit +from lumibot.tools.helpers import create_options_symbol, parse_timestep_qty_and_unit, get_trading_days from lumiwealth_tradier import Tradier from .data_source import DataSource @@ -202,6 +202,14 @@ def get_historical_prices( td, _ = self.convert_timestep_str_to_timedelta(timestep) start_date = end_date - (td * length) + if timestep == 'day' and timeshift is None: + # What we really want is the last n bars, not the bars from the last n days. + # get twice as many days as we need to ensure we get enough bars + tcal_start_date = end_date - (td * length * 2) + trading_days = get_trading_days(market='NYSE', start_date=tcal_start_date, end_date=end_date) + # Now, start_date is the length bars before the last trading day + start_date = trading_days.index[-length] + # Check what timestep we are using, different endpoints are required for different timesteps try: if parsed_timestep_unit == "minute": diff --git a/tests/test_bars.py b/tests/test_bars.py index 1e85e2555..78221e650 100644 --- a/tests/test_bars.py +++ b/tests/test_bars.py @@ -56,7 +56,7 @@ def setup_class(cls): df['expected_return'] = df['Adj Close'].pct_change() cls.expected_df = df - @pytest.mark.skip() + # @pytest.mark.skip() @pytest.mark.skipif(not ALPACA_CONFIG['API_KEY'], reason="This test requires an alpaca API key") @pytest.mark.skipif(ALPACA_CONFIG['API_KEY'] == '', reason="This test requires an alpaca API key") def test_alpaca_data_source_daily_bars(self): @@ -80,7 +80,7 @@ def test_alpaca_data_source_daily_bars(self): # check that there is no dividend column... This test will fail when dividends are added. We hope that's soon. assert "dividend" not in prices.df.columns - @pytest.mark.skip() + # @pytest.mark.skip() def test_yahoo_data_source_daily_bars(self): """ This tests that the yahoo data_source calculates adjusted returns for bars and that they @@ -130,7 +130,7 @@ def test_yahoo_data_source_daily_bars(self): rtol=0 ) - @pytest.mark.skip() + # @pytest.mark.skip() def test_pandas_data_source_daily_bars(self, pandas_data_fixture): """ This tests that the pandas data_source calculates adjusted returns for bars and that they @@ -181,7 +181,7 @@ def test_pandas_data_source_daily_bars(self, pandas_data_fixture): rtol=0 ) - @pytest.mark.skip() + # @pytest.mark.skip() @pytest.mark.skipif(POLYGON_API_KEY == '', reason="This test requires a Polygon.io API key") def test_polygon_data_source_daily_bars(self): """